Bad performance on large vision models #22
Dear Elliot,

Thanks a lot for your interest in our optimizer and for sharing your observations with us. YellowFin aims at providing good results without much tuning. However, as the SGD curve shows, the learning rate drop schedule is very important for improving accuracy. If you already know a good hand-tuned learning rate / momentum schedule, it can typically help you do even better. Still, we would like to study your specific case; we believe there is something interesting here.

If you want to fine-tune further to improve YF performance, I would suggest using the learning rate factor we provide in the implementation. The learning rate factor can 1) enforce a multiplier on the learning rate given by YF throughout training, and 2) implement a learning rate decay / learning rate drop schedule. Here is an example you can refer to: https://github.com/JianGoForIt/YellowFin_Pytorch/blob/master/pytorch-cifar/main.py#L161. We also have some examples in Appendix J4 of our new arXiv manuscript where YF, combined with learning rate drop/decay, further boosts performance.

We would love to study this specific case, so please let me know if you can share the code and data with us to play with.

Cheers,
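As a rough illustration of the two uses of the learning rate factor described above, here is a sketch built on a hypothetical stand-in class (`ToyTunedOptimizer` and `step_drop` are illustrative names, not part of YellowFin). The only idea taken from the thread is that the factor multiplies the learning rate the optimizer has tuned; the class does no actual tuning.

```python
class ToyTunedOptimizer:
    """Hypothetical stand-in for an auto-tuning optimizer such as YF.

    It stores the learning rate its tuner would choose plus a user-controlled
    multiplier; it does not implement any real tuning or parameter updates.
    """

    def __init__(self, tuned_lr):
        self.tuned_lr = tuned_lr  # value the internal tuner would produce
        self.lr_factor = 1.0      # 1.0 leaves the tuned value unchanged

    def set_lr_factor(self, factor):
        self.lr_factor = factor

    def effective_lr(self):
        # Step size actually applied: tuned value scaled by the user factor.
        return self.tuned_lr * self.lr_factor


def step_drop(epoch, drop=0.3, every=10):
    """Hypothetical drop schedule: multiply by `drop` every `every` epochs."""
    return drop ** (epoch // every)


# Use 1: a constant multiplier applied throughout training.
opt = ToyTunedOptimizer(tuned_lr=0.2)
opt.set_lr_factor(0.5)
constant = opt.effective_lr()

# Use 2: a drop schedule, re-applied at the start of each epoch.
dropped = []
for epoch in (0, 10, 20):
    opt.set_lr_factor(step_drop(epoch))
    dropped.append(opt.effective_lr())
```

In a real YF run the tuned value itself changes as training progresses, so the same factor can yield a different effective rate from one epoch to the next; the toy keeps `tuned_lr` fixed only to make the multiplication visible.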
Thank you for the quick response, Jian! As I stated above, both of these training runs have the same learning rate drop schedule applied. The code I'm using to apply my lambda callback is given below:

    def lambda_lr_callback(lr_lambda, verbose=True):
        """
        Callback to scale the initial learning rate by the result of the given
        function every epoch, e.g. if `lr_lambda` returns `1.0` for an epoch, the
        learning rate is unchanged; if it returns `0.1`, the learning rate will be
        1/10th of its typical value.
        """
        from torch.optim.lr_scheduler import LambdaLR
        from ..utils import log

        scheduler = None

        def lambda_lr_cb(data):
            nonlocal scheduler
            optimizer = data['optimizer']
            epoch = data['model'].epoch
            lr_to_print = None

            # If this is a YF-style optimizer, use set_lr_factor():
            if hasattr(optimizer, 'set_lr_factor'):
                optimizer.set_lr_factor(lr_lambda(epoch))
                # YF-style, we're just going to print _lr and hope for the best.
                lr_to_print = optimizer._lr
            else:
                # Otherwise, use a LambdaLR scheduler, and just tell it to step()
                if scheduler is None:
                    scheduler = LambdaLR(data['optimizer'], lr_lambda=lr_lambda)
                scheduler.step(epoch=epoch)
                lr_to_print = scheduler.get_lr()[0]

            # Print out the learning rate for fun
            if verbose:
                log("Optimizer learning rate: %.2e" % (lr_to_print))
            return True

        return lambda_lr_cb

This code is used to create a callback that is applied to the model/optimizer at the beginning of every epoch. The learning rate we are logging (and which I have plotted above) is the result of the
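For the non-YF branch, `LambdaLR` sets each parameter group's learning rate to its initial value times `lr_lambda(epoch)`. The sketch below mimics that rule in plain Python so the per-epoch behaviour can be checked without PyTorch. `ToySGD`, `make_lambda_lr_cb`, and the `data` dict are hypothetical, shaped only after what the callback above reads (`data['optimizer']`, `data['model'].epoch`).

```python
from types import SimpleNamespace


class ToySGD:
    """Hypothetical optimizer with torch-style param_groups (no real updates)."""

    def __init__(self, lr):
        self.param_groups = [{"lr": lr, "initial_lr": lr}]


def make_lambda_lr_cb(lr_lambda):
    """Simplified LambdaLR-style callback: lr = initial_lr * lr_lambda(epoch)."""

    def cb(data):
        epoch = data["model"].epoch
        for group in data["optimizer"].param_groups:
            # Always scale the *initial* rate, never the current one.
            group["lr"] = group["initial_lr"] * lr_lambda(epoch)
        return True

    return cb


optimizer = ToySGD(lr=0.1)
model = SimpleNamespace(epoch=0)
callback = make_lambda_lr_cb(lambda epoch: 0.3 ** (epoch // 10))

history = []
for epoch in range(30):
    model.epoch = epoch  # the training loop advances the epoch counter
    callback({"optimizer": optimizer, "model": model})  # start of each epoch
    history.append(optimizer.param_groups[0]["lr"])
```

Because the lambda always scales a fixed initial rate, the SGD learning rate follows a clean staircase (0.1, then about 0.03, then about 0.009). In the YF branch the factor instead multiplies a rate the optimizer keeps re-tuning, which may be one reason the drops are harder to see in a YF plot.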
I am training on ImageNet, which is quite a lot of data; if you train any large image classification models, you will probably train on ImageNet. My code is not special. It is very similar to this repository, but I have added a lot of other machinery for things like callbacks, regularization, model loading and saving, multi-GPU training, etc. If you would like to look at my code, I will add you to a private repository, but the data we train on is so large, and stored in such a special format, that you will not be able to run it locally without me providing it to you (it is over 100 GB, so while I could find a way to get it to you, it will not be easy).
Hello there,
I am doing my best to learn how to use this optimizer, as I would very much like to have an auto-tuned optimizer where I do not have to spend endless days fiddling with hyperparameters. I have tried to use YellowFin to train large vision models such as MobileNet, but my results are always very disappointing compared to a traditional optimizer such as `SGD`. I am not so concerned about convergence time as I am about loss/accuracy; I have found that YellowFin tends to converge to a much worse loss/accuracy than my SGD runs do.

I am posting here an example of training MobileNet on the ImageNet dataset with a batch size of 64, comparing the training and testing loss (as well as testing accuracy) over a few epochs of training. In both cases, I have a learning rate schedule applied that sets the learning rate factor to `0.3 ** (epoch // 10)`, which causes the learning rate to fall to 3/10 of its value every 10 epochs. You can see the effect of this learning rate schedule in the `sgd` plot fairly easily; the `yf` plot shows it less clearly.

In these figures, the training loss (per minibatch) is shown in blue, while the testing loss (per epoch) is shown in red, with the relevant axis shown on the left. The top-1 and top-5 accuracies on the training dataset are shown in green (per epoch), with their relevant axis given on the right. Other than the optimizer choice, all other training settings are the same, including minibatch size (64), dataset (ImageNet) and model architecture (MobileNet).

Here is a plot for an SGD optimizer run (note that this model is only partially trained; it has trained long enough that we can already see it will converge to a significantly better loss than the YF model below):
Here is a plot for a YellowFin optimizer run:
If there are any questions about my methodology, I would be happy to explain in greater detail. There is nothing particularly special going on in my model; I am simply trying to determine why YellowFin seems to converge to such poor results.
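The learning rate factor used in both runs is a step function of the epoch. A short check of its values makes the "3/10 every 10 epochs" arithmetic explicit: after 30 epochs the rate is 0.3^3 = 0.027 of its original value. This is a sketch; `lr_factor` is just an illustrative name. Note that in Python the exponent operator is `**`, while `^` is bitwise XOR and raises a `TypeError` on floats.

```python
def lr_factor(epoch):
    # Integer division keeps the factor constant within each 10-epoch block,
    # then multiplies it by another 0.3 at epochs 10, 20, 30, ...
    return 0.3 ** (epoch // 10)


factors = {epoch: lr_factor(epoch) for epoch in (0, 9, 10, 19, 20, 30)}
```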