Skip to content

Commit

Permalink
docs: add timm backbone tutorial (#1631)
Browse files Browse the repository at this point in the history
* docs: add timm backbone tutorial

* docs: update timm tutorial

* docs: update tutorial

Co-authored-by: guarin <[email protected]>

---------

Co-authored-by: guarin <[email protected]>
  • Loading branch information
SauravMaheshkar and guarin authored Aug 14, 2024
1 parent 48ac001 commit 222d97b
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 1 deletion.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ Want to jump to the tutorials and see Lightly in action?
- [Use Lightly with Custom Augmentations](https://docs.lightly.ai/self-supervised-learning/tutorials/package/tutorial_custom_augmentations.html)
- [Pre-train a Detectron2 Backbone with Lightly](https://docs.lightly.ai/self-supervised-learning/tutorials/package/tutorial_pretrain_detectron2.html)
- [Finetuning Lightly Checkpoints](https://docs.lightly.ai/self-supervised-learning/tutorials/package/tutorial_checkpoint_finetuning.html)
- [Using timm Models as Backbones](https://docs.lightly.ai/self-supervised-learning/tutorials/package/tutorial_timm_backbone.html)

Community and partner projects:

Expand Down Expand Up @@ -133,7 +134,7 @@ class SimCLR(torch.nn.Module):
return z


# Use a resnet backbone.
# Use a resnet backbone from torchvision.
backbone = torchvision.models.resnet18()
# Ignore the classification head as we only want the features.
backbone.fc = torch.nn.Identity()
Expand Down
1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ Lightly AI
tutorials/package/tutorial_custom_augmentations.rst
tutorials/package/tutorial_pretrain_detectron2.rst
tutorials/package/tutorial_checkpoint_finetuning.rst
tutorials/package/tutorial_timm_backbone.rst

.. toctree::
:maxdepth: 1
Expand Down
85 changes: 85 additions & 0 deletions docs/source/tutorials_source/package/tutorial_timm_backbone.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
"""
.. _lightly-timm-backbone-tutorial-8:
Tutorial 8: Using timm Models as Backbones
===========================================
You can use LightlySSL to pre-train any timm model using self-supervised learning since
most methods share a similar workflow. All methods have a base model (the backbone), which
can be any fundamental model such as ResNet or MobileNet.
In this tutorial, we will learn how to use a model architecture from the timm library
as a backbone in a self-supervised learning workflow.
"""

# %%
# Imports
# -------
#
# Import the Python frameworks we need for this tutorial.
# Make sure you have the necessary packages installed.
#
# .. code-block:: console
#
# pip install lightly"[timm]"


import timm
import torch
import torch.nn as nn

# %%
# timm comes packaged with >700 pre-trained models designed to be flexible and easy to use.
# These pre-trained models can be loaded using the
# `create_model() <https://huggingface.co/docs/timm/v1.0.8/en/reference/models#timm.create_model>`_
# function. For example, we can use the following snippet to create an efficient model.

backbone = timm.create_model(model_name="efficientnet_b0")


# %%
# Using a timm Model as a Backbone
# ---------------------------------
#
# We can now use this model as a backbone for training. Let's see how we can
# create a torch module for the `SimCLR <https://arxiv.org/abs/2002.05709>`_ method.

from lightly.models.modules.heads import SimCLRProjectionHead


class SimCLR(torch.nn.Module):
def __init__(self, backbone, feature_dim, out_dim):
super().__init__()
self.backbone = backbone
self.projection_head = SimCLRProjectionHead(feature_dim, feature_dim, out_dim)

def forward(self, x):
features = self.backbone.forward_features(x)
h = self.backbone.global_pool(features).flatten(start_dim=1)
z = self.projection_head(h)
return z


simclr = SimCLR(backbone, feature_dim=1280, out_dim=128)

# check if it works
input_a = torch.randn((2, 3, 224, 224))
input_b = torch.randn((2, 3, 224, 224))
out_a = simclr(input_a)
out_b = simclr(input_b)

# %%
# Next Steps
# ------------
#
# For an indepth tutorial on fine-tuning a model using `SimCLR <https://arxiv.org/abs/2002.05709>`_
# you can refer to our fine-tuning :ref:`lightly-checkpoint-finetuning-tutorial-7`.
# Interested in pre-training your own self-supervised models? Check out our other
# tutorials:
#
# - :ref:`input-structure-label`
# - :ref:`lightly-moco-tutorial-2`
# - :ref:`lightly-simsiam-tutorial-4`
# - :ref:`lightly-custom-augmentation-5`
# - :ref:`lightly-detectron-tutorial-6`
#

0 comments on commit 222d97b

Please sign in to comment.