[BUG] ValueError: [scatter] Cannot calculate VJP with respect to indices. #1439
Comments
It's probably a good idea to add an all-zeros VJP with respect to the indices for scatter, as we did for gather, for consistency. Is this for your MLX implementation of DINO? Looking at the PyTorch implementation, it seems they get around this by zeroing out the gradients for the Hungarian matcher. Maybe that will work for you in the meantime?
Hi @barronalex: Yes, this is for the implementation of DINO. It seems like zeroing out the gradient should solve the problem for the simple greedy matcher. I'm not sure it would solve the problem for the Hungarian matcher, though, since that uses scipy's linear_sum_assignment. This is a particular problem because it requires evaluating the cost matrix, which in turn prevents me from calling mx.compile on the value_and_grad function used for training. Please correct me if I'm mistaken.
Yes, you would need an MLX implementation of
Thanks @barronalex, understood. Referring to this JAX implementation, there is another challenge in writing an MLX implementation in Python that can be compiled:
Is We could definitely implement something like
Thanks for your response, @barronalex! I realise that the HungarianMatcher is not a big performance bottleneck. However, I'm seeing memory inflation in the MLX port of the prepare_for_cdn function, and I thought that compiling the entire computation graph might solve the problem.
Describe the bug
I understand that the HungarianMatching algorithm requires linear_sum_assignment from scipy, which needs the cost matrix to be evaluated. Hence, I cannot compile my train step function. However, if I use the SimpleMatching algorithm and then compile my train step, I get the following error: ValueError: [scatter] Cannot calculate VJP with respect to indices.
Code for matching is as follows:
Also, is there a workaround to get mx.compile working with the Hungarian Matching algorithm?
I would greatly appreciate your help in solving this.
Additional context
Using MLX 0.17.3