This repository contains the PyTorch 🔥 implementation of an image captioning model that uses attention.

Demo.
To try it, run the following commands:
- Install the necessary Python packages:
  pip install -r requirements.txt
- Change the DATA_PATH, caption_file & images_directory paths in the data.py file.
- Train the model with:
  python train.py
- Then open a terminal and launch the app:
  python app.py
The app should then be usable in the browser on localhost.
The deep learning model was trained on the Flickr8k image-captioning dataset.
The model contains 3 main components:
- An encoder that extracts image features with the pre-trained ResNet50 model (trained on the ImageNet dataset).
- An attention mechanism, so that the neural network knows which part of the input image to focus on when decoding each word.
- An LSTM decoder that generates the captions and returns the attention alphas along with them.
The model was trained for 25 epochs, which took about 2.5 hours. Performance could be improved with more data, a more sophisticated neural network, and more training iterations.
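The training loop in train.py presumably follows the standard teacher-forcing pattern: predict each caption token given the image and the preceding ground-truth tokens. A hedged, self-contained sketch (the function name, padding index, and model interface are illustrative assumptions):

```python
import torch
import torch.nn as nn

def train_epoch(model, loader, optimizer, device, pad_idx=0):
    """One epoch of teacher-forced caption training (illustrative sketch).

    Assumes model(images, input_tokens) returns logits of shape
    (B, T-1, vocab_size) and that pad_idx marks padding in captions.
    """
    criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)
    model.train()
    total_loss = 0.0
    for images, captions in loader:            # captions: (B, T) token ids
        images, captions = images.to(device), captions.to(device)
        logits = model(images, captions[:, :-1])    # predict tokens 1..T-1
        loss = criterion(logits.reshape(-1, logits.size(-1)),
                         captions[:, 1:].reshape(-1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)
```

Shifting the captions by one position means the decoder always conditions on the ground-truth previous word during training, which speeds up convergence compared to feeding back its own predictions.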
We can make a prediction on the following example image to get the corresponding caption:
Then we can visualize the attention values over different regions of the image for each generated word token.
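The visualization step amounts to upsampling each word's attention vector (one weight per image region) back to the image resolution and overlaying it as a heatmap. A small sketch of that idea, assuming a 7×7 feature grid as ResNet50 produces for 224×224 inputs (the function name and grid size are illustrative):

```python
import numpy as np

def attention_overlay(alphas, image_hw, grid=7):
    """Upsample a (grid*grid,) attention vector to the image size so it
    can be overlaid on the picture for one generated word (illustrative)."""
    attn = np.asarray(alphas, dtype=np.float32).reshape(grid, grid)
    H, W = image_hw
    # Nearest-neighbour upsampling via index mapping (numpy only)
    rows = np.arange(H) * grid // H
    cols = np.arange(W) * grid // W
    heat = attn[rows][:, cols]           # (H, W)
    return heat / heat.max()             # normalise to [0, 1] for display
```

The normalized heatmap can then be blended over the image, for example with matplotlib's `imshow` using its `alpha` parameter, one subplot per generated word.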