diff --git a/.github/workflows/check_example_nbs.yml b/.github/workflows/check_example_nbs.yml new file mode 100644 index 000000000..2792b2e4f --- /dev/null +++ b/.github/workflows/check_example_nbs.yml @@ -0,0 +1,41 @@ +name: "Check Example Notebooks" + +on: + push: + branches: [master] + paths: + - "examples/**/*.py" + pull_request: + branches: [master] + paths: + - "examples/**/*.py" + workflow_dispatch: + +jobs: + convert-to-nbs: + name: "Check Example Notebooks" + runs-on: ubuntu-latest + steps: + - name: Checkout Code + uses: actions/checkout@v4 + - name: Set Up Python + uses: actions/setup-python@v5 + with: + python-version: 3.7 + - name: Set Up Environment + run: | + make install-uv reset-venv + source .venv/bin/activate + make install-pinned-extras + - name: Convert using Script + run: | + source .venv/bin/activate + make generate-example-notebooks + - name: Check for diff + run: | + source .venv/bin/activate + git add examples/notebooks/ + if ! git diff --cached --exit-code; then + echo 'Notebooks have changed! Please run `make generate-example-notebooks` and commit the changes.' + exit 1 + fi diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9fe22151b..58361f4b8 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,24 +1,30 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v2.3.0 + rev: v4.6.0 hooks: - id: detect-private-key # check for private keys - id: check-added-large-files # prevent commit of files >500kB args: ['--maxkb=500'] - repo: https://github.com/psf/black - rev: 23.1.0 # aligned with the version defined in pyproject.toml + rev: 24.8.0 # aligned with the version defined in pyproject.toml hooks: - id: black - repo: https://github.com/pycqa/isort - rev: 5.11.5 # aligned with the version defined in pyproject.toml + rev: 5.13.2 # aligned with the version defined in pyproject.toml hooks: - id: isort - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.4.1 # aligned with the version defined in pyproject.toml + rev: v1.11.1 # aligned with the version defined in pyproject.toml hooks: - id: mypy additional_dependencies: - 'numpy' +- repo: https://github.com/nbQA-dev/nbQA + rev: 1.8.7 + hooks: + - id: nbqa-black + - id: nbqa-isort + args: ["--profile=black"] - repo: local hooks: - id: pytest-check # run all tests diff --git a/Makefile b/Makefile index 263393f11..c7cfeb2b9 100644 --- a/Makefile +++ b/Makefile @@ -185,3 +185,11 @@ install-pinned-extras: .PHONY: install-latest install-latest: uv pip install --upgrade --reinstall ${EDITABLE} .
--all-extras --requirement pyproject.toml + + +# Generate Notebooks from examples +.PHONY: generate-example-notebooks +generate-example-notebooks: + python examples/create_example_nbs.py examples/pytorch examples/notebooks/pytorch + python examples/create_example_nbs.py examples/pytorch_lightning examples/notebooks/pytorch_lightning + python examples/create_example_nbs.py examples/pytorch_lightning_distributed examples/notebooks/pytorch_lightning_distributed diff --git a/README.md b/README.md index e78d075b1..6f0f4149d 100644 --- a/README.md +++ b/README.md @@ -39,25 +39,25 @@ and PyTorch Lightning distributed examples for all models to kickstart your proj **Models**: -- AIM, 2024 [paper](https://arxiv.org/abs/2401.08541) [docs](https://docs.lightly.ai/self-supervised-learning/examples/aim.html) -- Barlow Twins, 2021 [paper](https://arxiv.org/abs/2103.03230) [docs](https://docs.lightly.ai/self-supervised-learning/examples/barlowtwins.html) -- BYOL, 2020 [paper](https://arxiv.org/abs/2006.07733) [docs](https://docs.lightly.ai/self-supervised-learning/examples/byol.html) -- DCL & DCLW, 2021 [paper](https://arxiv.org/abs/2110.06848) [docs](https://docs.lightly.ai/self-supervised-learning/examples/dcl.html) -- DenseCL, 2021 [paper](https://arxiv.org/abs/2011.09157) [docs](https://docs.lightly.ai/self-supervised-learning/examples/densecl.html) -- DINO, 2021 [paper](https://arxiv.org/abs/2104.14294) [docs](https://docs.lightly.ai/self-supervised-learning/examples/dino.html) -- MAE, 2021 [paper](https://arxiv.org/abs/2111.06377) [docs](https://docs.lightly.ai/self-supervised-learning/examples/mae.html) -- MSN, 2022 [paper](https://arxiv.org/abs/2204.07141) [docs](https://docs.lightly.ai/self-supervised-learning/examples/msn.html) -- MoCo, 2019 [paper](https://arxiv.org/abs/1911.05722) [docs](https://docs.lightly.ai/self-supervised-learning/examples/moco.html) -- NNCLR, 2021 [paper](https://arxiv.org/abs/2104.14548) [docs](https://docs.lightly.ai/self-supervised-learning/examples/nnclr.html) -- PMSN, 2022 [paper](https://arxiv.org/abs/2210.07277) [docs](https://docs.lightly.ai/self-supervised-learning/examples/pmsn.html) -- SimCLR, 2020 [paper](https://arxiv.org/abs/2002.05709) [docs](https://docs.lightly.ai/self-supervised-learning/examples/simclr.html) -- SimMIM, 2021 [paper](https://arxiv.org/abs/2111.09886) [docs](https://docs.lightly.ai/self-supervised-learning/examples/simmim.html) -- SimSiam, 2021 [paper](https://arxiv.org/abs/2011.10566) [docs](https://docs.lightly.ai/self-supervised-learning/examples/simsiam.html) -- SMoG, 2022 [paper](https://arxiv.org/abs/2207.06167) [docs](https://docs.lightly.ai/self-supervised-learning/examples/smog.html) -- SwaV, 2020 [paper](https://arxiv.org/abs/2006.09882) [docs](https://docs.lightly.ai/self-supervised-learning/examples/swav.html) -- TiCo, 2022 [paper](https://arxiv.org/abs/2206.10698) [docs](https://docs.lightly.ai/self-supervised-learning/examples/tico.html) -- VICReg, 2022 [paper](https://arxiv.org/abs/2105.04906) [docs](https://docs.lightly.ai/self-supervised-learning/examples/vicreg.html) -- VICRegL, 2022 [paper](https://arxiv.org/abs/2210.01571) [docs](https://docs.lightly.ai/self-supervised-learning/examples/vicregl.html) +- AIM, 2024 [paper](https://arxiv.org/abs/2401.08541) [docs](https://docs.lightly.ai/self-supervised-learning/examples/aim.html) [![Open In 
Colab](https://img.shields.io/badge/Colab-PyTorch-blue?logo=googlecolab)](https://colab.research.google.com/github/lightly-ai/lightly/blob/master/examples/notebooks/pytorch/aim.ipynb) [![Open In Colab](https://img.shields.io/badge/Colab-PyTorch_Lightning-blue?logo=googlecolab)](https://colab.research.google.com/github/lightly-ai/lightly/blob/master/examples/notebooks/pytorch_lightning/aim.ipynb) +- Barlow Twins, 2021 [paper](https://arxiv.org/abs/2103.03230) [docs](https://docs.lightly.ai/self-supervised-learning/examples/barlowtwins.html) [![Open In Colab](https://img.shields.io/badge/Colab-PyTorch-blue?logo=googlecolab)](https://colab.research.google.com/github/lightly-ai/lightly/blob/master/examples/notebooks/pytorch/barlowtwins.ipynb) [![Open In Colab](https://img.shields.io/badge/Colab-PyTorch_Lightning-blue?logo=googlecolab)](https://colab.research.google.com/github/lightly-ai/lightly/blob/master/examples/notebooks/pytorch_lightning/barlowtwins.ipynb) +- BYOL, 2020 [paper](https://arxiv.org/abs/2006.07733) [docs](https://docs.lightly.ai/self-supervised-learning/examples/byol.html) [![Open In Colab](https://img.shields.io/badge/Colab-PyTorch-blue?logo=googlecolab)](https://colab.research.google.com/github/lightly-ai/lightly/blob/master/examples/notebooks/pytorch/byol.ipynb) [![Open In Colab](https://img.shields.io/badge/Colab-PyTorch_Lightning-blue?logo=googlecolab)](https://colab.research.google.com/github/lightly-ai/lightly/blob/master/examples/notebooks/pytorch_lightning/byol.ipynb) +- DCL & DCLW, 2021 [paper](https://arxiv.org/abs/2110.06848) [docs](https://docs.lightly.ai/self-supervised-learning/examples/dcl.html) [![Open In Colab](https://img.shields.io/badge/Colab-PyTorch-blue?logo=googlecolab)](https://colab.research.google.com/github/lightly-ai/lightly/blob/master/examples/notebooks/pytorch/dcl.ipynb) [![Open In Colab](https://img.shields.io/badge/Colab-PyTorch_Lightning-blue?logo=googlecolab)](https://colab.research.google.com/github/lightly-ai/lightly/blob/master/examples/notebooks/pytorch_lightning/dcl.ipynb) +- DenseCL, 2021 [paper](https://arxiv.org/abs/2011.09157) [docs](https://docs.lightly.ai/self-supervised-learning/examples/densecl.html) [![Open In Colab](https://img.shields.io/badge/Colab-PyTorch-blue?logo=googlecolab)](https://colab.research.google.com/github/lightly-ai/lightly/blob/master/examples/notebooks/pytorch/densecl.ipynb) [![Open In Colab](https://img.shields.io/badge/Colab-PyTorch_Lightning-blue?logo=googlecolab)](https://colab.research.google.com/github/lightly-ai/lightly/blob/master/examples/notebooks/pytorch_lightning/densecl.ipynb) +- DINO, 2021 [paper](https://arxiv.org/abs/2104.14294) [docs](https://docs.lightly.ai/self-supervised-learning/examples/dino.html) [![Open In Colab](https://img.shields.io/badge/Colab-PyTorch-blue?logo=googlecolab)](https://colab.research.google.com/github/lightly-ai/lightly/blob/master/examples/notebooks/pytorch/dino.ipynb) [![Open In Colab](https://img.shields.io/badge/Colab-PyTorch_Lightning-blue?logo=googlecolab)](https://colab.research.google.com/github/lightly-ai/lightly/blob/master/examples/notebooks/pytorch_lightning/dino.ipynb) +- MAE, 2021 [paper](https://arxiv.org/abs/2111.06377) [docs](https://docs.lightly.ai/self-supervised-learning/examples/mae.html) [![Open In Colab](https://img.shields.io/badge/Colab-PyTorch-blue?logo=googlecolab)](https://colab.research.google.com/github/lightly-ai/lightly/blob/master/examples/notebooks/pytorch/mae.ipynb) [![Open In 
Colab](https://img.shields.io/badge/Colab-PyTorch_Lightning-blue?logo=googlecolab)](https://colab.research.google.com/github/lightly-ai/lightly/blob/master/examples/notebooks/pytorch_lightning/mae.ipynb) +- MSN, 2022 [paper](https://arxiv.org/abs/2204.07141) [docs](https://docs.lightly.ai/self-supervised-learning/examples/msn.html) [![Open In Colab](https://img.shields.io/badge/Colab-PyTorch-blue?logo=googlecolab)](https://colab.research.google.com/github/lightly-ai/lightly/blob/master/examples/notebooks/pytorch/msn.ipynb) [![Open In Colab](https://img.shields.io/badge/Colab-PyTorch_Lightning-blue?logo=googlecolab)](https://colab.research.google.com/github/lightly-ai/lightly/blob/master/examples/notebooks/pytorch_lightning/msn.ipynb) +- MoCo, 2019 [paper](https://arxiv.org/abs/1911.05722) [docs](https://docs.lightly.ai/self-supervised-learning/examples/moco.html) [![Open In Colab](https://img.shields.io/badge/Colab-PyTorch-blue?logo=googlecolab)](https://colab.research.google.com/github/lightly-ai/lightly/blob/master/examples/notebooks/pytorch/moco.ipynb) [![Open In Colab](https://img.shields.io/badge/Colab-PyTorch_Lightning-blue?logo=googlecolab)](https://colab.research.google.com/github/lightly-ai/lightly/blob/master/examples/notebooks/pytorch_lightning/moco.ipynb) +- NNCLR, 2021 [paper](https://arxiv.org/abs/2104.14548) [docs](https://docs.lightly.ai/self-supervised-learning/examples/nnclr.html) [![Open In Colab](https://img.shields.io/badge/Colab-PyTorch-blue?logo=googlecolab)](https://colab.research.google.com/github/lightly-ai/lightly/blob/master/examples/notebooks/pytorch/nnclr.ipynb) [![Open In Colab](https://img.shields.io/badge/Colab-PyTorch_Lightning-blue?logo=googlecolab)](https://colab.research.google.com/github/lightly-ai/lightly/blob/master/examples/notebooks/pytorch_lightning/nnclr.ipynb) +- PMSN, 2022 [paper](https://arxiv.org/abs/2210.07277) [docs](https://docs.lightly.ai/self-supervised-learning/examples/pmsn.html) [![Open In Colab](https://img.shields.io/badge/Colab-PyTorch-blue?logo=googlecolab)](https://colab.research.google.com/github/lightly-ai/lightly/blob/master/examples/notebooks/pytorch/pmsn.ipynb) [![Open In Colab](https://img.shields.io/badge/Colab-PyTorch_Lightning-blue?logo=googlecolab)](https://colab.research.google.com/github/lightly-ai/lightly/blob/master/examples/notebooks/pytorch_lightning/pmsn.ipynb) +- SimCLR, 2020 [paper](https://arxiv.org/abs/2002.05709) [docs](https://docs.lightly.ai/self-supervised-learning/examples/simclr.html) [![Open In Colab](https://img.shields.io/badge/Colab-PyTorch-blue?logo=googlecolab)](https://colab.research.google.com/github/lightly-ai/lightly/blob/master/examples/notebooks/pytorch/simclr.ipynb) [![Open In Colab](https://img.shields.io/badge/Colab-PyTorch_Lightning-blue?logo=googlecolab)](https://colab.research.google.com/github/lightly-ai/lightly/blob/master/examples/notebooks/pytorch_lightning/simclr.ipynb) +- SimMIM, 2021 [paper](https://arxiv.org/abs/2111.09886) [docs](https://docs.lightly.ai/self-supervised-learning/examples/simmim.html) [![Open In Colab](https://img.shields.io/badge/Colab-PyTorch-blue?logo=googlecolab)](https://colab.research.google.com/github/lightly-ai/lightly/blob/master/examples/notebooks/pytorch/simmim.ipynb) [![Open In Colab](https://img.shields.io/badge/Colab-PyTorch_Lightning-blue?logo=googlecolab)](https://colab.research.google.com/github/lightly-ai/lightly/blob/master/examples/notebooks/pytorch_lightning/simmim.ipynb) +- SimSiam, 2021 [paper](https://arxiv.org/abs/2011.10566) 
[docs](https://docs.lightly.ai/self-supervised-learning/examples/simsiam.html) [![Open In Colab](https://img.shields.io/badge/Colab-PyTorch-blue?logo=googlecolab)](https://colab.research.google.com/github/lightly-ai/lightly/blob/master/examples/notebooks/pytorch/simsiam.ipynb) [![Open In Colab](https://img.shields.io/badge/Colab-PyTorch_Lightning-blue?logo=googlecolab)](https://colab.research.google.com/github/lightly-ai/lightly/blob/master/examples/notebooks/pytorch_lightning/simsiam.ipynb) +- SMoG, 2022 [paper](https://arxiv.org/abs/2207.06167) [docs](https://docs.lightly.ai/self-supervised-learning/examples/smog.html) [![Open In Colab](https://img.shields.io/badge/Colab-PyTorch-blue?logo=googlecolab)](https://colab.research.google.com/github/lightly-ai/lightly/blob/master/examples/notebooks/pytorch/smog.ipynb) [![Open In Colab](https://img.shields.io/badge/Colab-PyTorch_Lightning-blue?logo=googlecolab)](https://colab.research.google.com/github/lightly-ai/lightly/blob/master/examples/notebooks/pytorch_lightning/smog.ipynb) +- SwaV, 2020 [paper](https://arxiv.org/abs/2006.09882) [docs](https://docs.lightly.ai/self-supervised-learning/examples/swav.html) [![Open In Colab](https://img.shields.io/badge/Colab-PyTorch-blue?logo=googlecolab)](https://colab.research.google.com/github/lightly-ai/lightly/blob/master/examples/notebooks/pytorch/swav.ipynb) [![Open In Colab](https://img.shields.io/badge/Colab-PyTorch_Lightning-blue?logo=googlecolab)](https://colab.research.google.com/github/lightly-ai/lightly/blob/master/examples/notebooks/pytorch_lightning/swav.ipynb) +- TiCo, 2022 [paper](https://arxiv.org/abs/2206.10698) [docs](https://docs.lightly.ai/self-supervised-learning/examples/tico.html) [![Open In Colab](https://img.shields.io/badge/Colab-PyTorch-blue?logo=googlecolab)](https://colab.research.google.com/github/lightly-ai/lightly/blob/master/examples/notebooks/pytorch/tico.ipynb) [![Open In Colab](https://img.shields.io/badge/Colab-PyTorch_Lightning-blue?logo=googlecolab)](https://colab.research.google.com/github/lightly-ai/lightly/blob/master/examples/notebooks/pytorch_lightning/tico.ipynb) +- VICReg, 2022 [paper](https://arxiv.org/abs/2105.04906) [docs](https://docs.lightly.ai/self-supervised-learning/examples/vicreg.html) [![Open In Colab](https://img.shields.io/badge/Colab-PyTorch-blue?logo=googlecolab)](https://colab.research.google.com/github/lightly-ai/lightly/blob/master/examples/notebooks/pytorch/vicreg.ipynb) [![Open In Colab](https://img.shields.io/badge/Colab-PyTorch_Lightning-blue?logo=googlecolab)](https://colab.research.google.com/github/lightly-ai/lightly/blob/master/examples/notebooks/pytorch_lightning/vicreg.ipynb) +- VICRegL, 2022 [paper](https://arxiv.org/abs/2210.01571) [docs](https://docs.lightly.ai/self-supervised-learning/examples/vicregl.html) [![Open In Colab](https://img.shields.io/badge/Colab-PyTorch-blue?logo=googlecolab)](https://colab.research.google.com/github/lightly-ai/lightly/blob/master/examples/notebooks/pytorch/vicregl.ipynb) [![Open In Colab](https://img.shields.io/badge/Colab-PyTorch_Lightning-blue?logo=googlecolab)](https://colab.research.google.com/github/lightly-ai/lightly/blob/master/examples/notebooks/pytorch_lightning/vicregl.ipynb) ## Tutorials diff --git a/examples/README.md b/examples/README.md index d48e77c37..6c112ae21 100644 --- a/examples/README.md +++ b/examples/README.md @@ -19,3 +19,7 @@ The examples should also run on [Google Colab](https://colab.research.google.com You can find additional information for each model in our 
[Documentation](https://docs.lightly.ai//examples/models.html#) + + > [!IMPORTANT] + > The example notebooks are generated using the [`create_example_nbs.py`](./create_example_nbs.py) script and should not be modified manually. All changes must be made to the respective example. + diff --git a/examples/create_example_nbs.py b/examples/create_example_nbs.py new file mode 100644 index 000000000..385b97339 --- /dev/null +++ b/examples/create_example_nbs.py @@ -0,0 +1,52 @@ +import argparse +from pathlib import Path + +import jupytext +import nbformat +from nbformat import NotebookNode + + +def add_installation_cell(nb: NotebookNode, script_path: Path) -> NotebookNode: + # Find installation snippet + with open(script_path, "r") as f: + for line in f: + line = line.strip() + if line.startswith("# pip install"): + pip_command = "!" + line.lstrip("# ").strip() + # Create a new code cell + code_cell = nbformat.v4.new_code_cell(pip_command) + # Add the cell to the notebook + nb.cells.insert(1, code_cell) + break + + return nb + + +def convert_to_nbs(scripts_dir: Path, notebooks_dir: Path) -> None: + # Loop through all Python files in the directory + for py_file_path in scripts_dir.rglob("*.py"): + # Construct the full paths + notebook_path = notebooks_dir / py_file_path.relative_to( + scripts_dir + ).with_suffix(".ipynb") + print(f"Converting {py_file_path} to notebook...") + notebook = jupytext.read(py_file_path) + notebook = add_installation_cell(notebook, py_file_path) + # Make cell ids deterministic to avoid changing ids every time a notebook is (re)generated. + for i, cell in enumerate(notebook.cells): + cell.id = str(i) + jupytext.write(notebook, notebook_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "scripts_dir", + help="path to the directory containing the python scripts", + ) + parser.add_argument( + "notebooks_dir", + help="path to the directory where the generated notebooks are stored", + ) + args = parser.parse_args() + convert_to_nbs(Path(args.scripts_dir), Path(args.notebooks_dir)) diff --git a/examples/notebooks/pytorch/aim.ipynb b/examples/notebooks/pytorch/aim.ipynb new file mode 100644 index 000000000..78f891190 --- /dev/null +++ b/examples/notebooks/pytorch/aim.ipynb @@ -0,0 +1,207 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "This example requires the following dependencies to be installed:\n", + "pip install \"lightly[timm]\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install \"lightly[timm]\"" + ] + }, + { + "cell_type": "markdown", + "id": "2", + "metadata": {}, + "source": [ + "Note: The model and training settings do not follow the reference settings\n", + "from the paper. The settings are chosen such that the example can easily be\n", + "run on a small dataset with a single GPU."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torchvision\n", + "from torch import nn" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "from lightly.models import utils\n", + "from lightly.models.modules import AIMPredictionHead, MaskedCausalVisionTransformer\n", + "from lightly.transforms import AIMTransform" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "class AIM(nn.Module):\n", + " def __init__(self, vit):\n", + " super().__init__()\n", + " utils.initialize_2d_sine_cosine_positional_embedding(\n", + " pos_embedding=vit.pos_embed, has_class_token=vit.has_class_token\n", + " )\n", + " self.patch_size = vit.patch_embed.patch_size[0]\n", + " self.num_patches = vit.patch_embed.num_patches\n", + "\n", + " self.backbone = vit\n", + " self.projection_head = AIMPredictionHead(\n", + " input_dim=vit.embed_dim, output_dim=3 * self.patch_size**2, num_blocks=1\n", + " )\n", + "\n", + " def forward(self, images):\n", + " batch_size = images.shape[0]\n", + "\n", + " mask = utils.random_prefix_mask(\n", + " size=(batch_size, self.num_patches),\n", + " max_prefix_length=self.num_patches - 1,\n", + " device=images.device,\n", + " )\n", + " features = self.backbone.forward_features(images, mask=mask)\n", + " # Add positional embedding before head.\n", + " features = self.backbone._pos_embed(features)\n", + " predictions = self.projection_head(features)\n", + "\n", + " # Convert images to patches and normalize them.\n", + " patches = utils.patchify(images, self.patch_size)\n", + " patches = utils.normalize_mean_var(patches, dim=-1)\n", + "\n", + " return predictions, patches" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "vit = MaskedCausalVisionTransformer(\n", + " img_size=224,\n", + " patch_size=32,\n", + " embed_dim=768,\n", + " depth=12,\n", + " num_heads=12,\n", + " qk_norm=False,\n", + " class_token=False,\n", + " no_embed_class=True,\n", + ")\n", + "model = AIM(vit)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "model.to(device)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "transform = AIMTransform()\n", + "# we ignore object detection annotations by setting target_transform to return 0\n", + "dataset = torchvision.datasets.VOCDetection(\n", + " \"datasets/pascal_voc\",\n", + " download=True,\n", + " transform=transform,\n", + " target_transform=lambda t: 0,\n", + ")\n", + "# or create a dataset from a folder containing images or videos:\n", + "# dataset = LightlyDataset(\"path/to/folder\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "dataloader = torch.utils.data.DataLoader(\n", + " dataset,\n", + " batch_size=256,\n", + " shuffle=True,\n", + " drop_last=True,\n", + " num_workers=8,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [], + "source": [ + "criterion = nn.MSELoss()\n", + "optimizer = torch.optim.AdamW(model.parameters(), lr=1.5e-4)" + ] + }, + { + "cell_type": 
"code", + "execution_count": null, + "id": "11", + "metadata": {}, + "outputs": [], + "source": [ + "print(\"Starting Training\")\n", + "for epoch in range(10):\n", + " total_loss = 0\n", + " for batch in dataloader:\n", + " views = batch[0]\n", + " images = views[0].to(device) # views contains only a single view\n", + " predictions, targets = model(images)\n", + " loss = criterion(predictions, targets)\n", + " total_loss += loss.detach()\n", + " loss.backward()\n", + " optimizer.step()\n", + " optimizer.zero_grad()\n", + " avg_loss = total_loss / len(dataloader)\n", + " print(f\"epoch: {epoch:>02}, loss: {avg_loss:.5f}\")" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": "-all" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/notebooks/pytorch/barlowtwins.ipynb b/examples/notebooks/pytorch/barlowtwins.ipynb new file mode 100644 index 000000000..b48286d0e --- /dev/null +++ b/examples/notebooks/pytorch/barlowtwins.ipynb @@ -0,0 +1,184 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "This example requires the following dependencies to be installed:\n", + "pip install lightly" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install lightly" + ] + }, + { + "cell_type": "markdown", + "id": "2", + "metadata": {}, + "source": [ + "Note: The model and training settings do not follow the reference settings\n", + "from the paper. The settings are chosen such that the example can easily be\n", + "run on a small dataset with a single GPU." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torchvision\n", + "from torch import nn" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "from lightly.loss import BarlowTwinsLoss\n", + "from lightly.models.modules import BarlowTwinsProjectionHead\n", + "from lightly.transforms.byol_transform import (\n", + " BYOLTransform,\n", + " BYOLView1Transform,\n", + " BYOLView2Transform,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "class BarlowTwins(nn.Module):\n", + " def __init__(self, backbone):\n", + " super().__init__()\n", + " self.backbone = backbone\n", + " self.projection_head = BarlowTwinsProjectionHead(512, 2048, 2048)\n", + "\n", + " def forward(self, x):\n", + " x = self.backbone(x).flatten(start_dim=1)\n", + " z = self.projection_head(x)\n", + " return z" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "resnet = torchvision.models.resnet18()\n", + "backbone = nn.Sequential(*list(resnet.children())[:-1])\n", + "model = BarlowTwins(backbone)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "model.to(device)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "# BarlowTwins uses BYOL augmentations.\n", + "# We disable resizing and gaussian blur for cifar10.\n", + "transform = BYOLTransform(\n", + " view_1_transform=BYOLView1Transform(input_size=32, 
gaussian_blur=0.0),\n", + " view_2_transform=BYOLView2Transform(input_size=32, gaussian_blur=0.0),\n", + ")\n", + "dataset = torchvision.datasets.CIFAR10(\n", + " \"datasets/cifar10\", download=True, transform=transform\n", + ")\n", + "# or create a dataset from a folder containing images or videos:\n", + "# dataset = LightlyDataset(\"path/to/folder\", transform=transform)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "dataloader = torch.utils.data.DataLoader(\n", + " dataset,\n", + " batch_size=256,\n", + " shuffle=True,\n", + " drop_last=True,\n", + " num_workers=8,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [], + "source": [ + "criterion = BarlowTwinsLoss()\n", + "optimizer = torch.optim.SGD(model.parameters(), lr=0.06)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "11", + "metadata": {}, + "outputs": [], + "source": [ + "print(\"Starting Training\")\n", + "for epoch in range(10):\n", + " total_loss = 0\n", + " for batch in dataloader:\n", + " x0, x1 = batch[0]\n", + " x0 = x0.to(device)\n", + " x1 = x1.to(device)\n", + " z0 = model(x0)\n", + " z1 = model(x1)\n", + " loss = criterion(z0, z1)\n", + " total_loss += loss.detach()\n", + " loss.backward()\n", + " optimizer.step()\n", + " optimizer.zero_grad()\n", + " avg_loss = total_loss / len(dataloader)\n", + " print(f\"epoch: {epoch:>02}, loss: {avg_loss:.5f}\")" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": "-all" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/notebooks/pytorch/byol.ipynb b/examples/notebooks/pytorch/byol.ipynb new file mode 100644 index 000000000..ed3a73e3f --- /dev/null +++ b/examples/notebooks/pytorch/byol.ipynb @@ -0,0 +1,227 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "This example requires the following dependencies to be installed:\n", + "pip install lightly" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install lightly" + ] + }, + { + "cell_type": "markdown", + "id": "2", + "metadata": {}, + "source": [ + "Note: The model and training settings do not follow the reference settings\n", + "from the paper. The settings are chosen such that the example can easily be\n", + "run on a small dataset with a single GPU." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "import copy" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torchvision\n", + "from torch import nn" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "from lightly.loss import NegativeCosineSimilarity\n", + "from lightly.models.modules import BYOLPredictionHead, BYOLProjectionHead\n", + "from lightly.models.utils import deactivate_requires_grad, update_momentum\n", + "from lightly.transforms.byol_transform import (\n", + " BYOLTransform,\n", + " BYOLView1Transform,\n", + " BYOLView2Transform,\n", + ")\n", + "from lightly.utils.scheduler import cosine_schedule" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "class BYOL(nn.Module):\n", + " def __init__(self, backbone):\n", + " super().__init__()\n", + "\n", + " self.backbone = backbone\n", + " self.projection_head = BYOLProjectionHead(512, 1024, 256)\n", + " self.prediction_head = BYOLPredictionHead(256, 1024, 256)\n", + "\n", + " self.backbone_momentum = copy.deepcopy(self.backbone)\n", + " self.projection_head_momentum = copy.deepcopy(self.projection_head)\n", + "\n", + " deactivate_requires_grad(self.backbone_momentum)\n", + " deactivate_requires_grad(self.projection_head_momentum)\n", + "\n", + " def forward(self, x):\n", + " y = self.backbone(x).flatten(start_dim=1)\n", + " z = self.projection_head(y)\n", + " p = self.prediction_head(z)\n", + " return p\n", + "\n", + " def forward_momentum(self, x):\n", + " y = self.backbone_momentum(x).flatten(start_dim=1)\n", + " z = self.projection_head_momentum(y)\n", + " z = z.detach()\n", + " return z" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "resnet = torchvision.models.resnet18()\n", + "backbone = nn.Sequential(*list(resnet.children())[:-1])\n", + "model = BYOL(backbone)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "model.to(device)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "# We disable resizing and gaussian blur for cifar10.\n", + "transform = BYOLTransform(\n", + " view_1_transform=BYOLView1Transform(input_size=32, gaussian_blur=0.0),\n", + " view_2_transform=BYOLView2Transform(input_size=32, gaussian_blur=0.0),\n", + ")\n", + "dataset = torchvision.datasets.CIFAR10(\n", + " \"datasets/cifar10\", download=True, transform=transform\n", + ")\n", + "# or create a dataset from a folder containing images or videos:\n", + "# dataset = LightlyDataset(\"path/to/folder\", transform=transform)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [], + "source": [ + "dataloader = torch.utils.data.DataLoader(\n", + " dataset,\n", + " batch_size=256,\n", + " shuffle=True,\n", + " drop_last=True,\n", + " num_workers=8,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "11", + "metadata": {}, + "outputs": [], + "source": [ + "criterion = NegativeCosineSimilarity()\n", + "optimizer = torch.optim.SGD(model.parameters(), 
lr=0.06)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "12", + "metadata": {}, + "outputs": [], + "source": [ + "epochs = 10" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "13", + "metadata": {}, + "outputs": [], + "source": [ + "print(\"Starting Training\")\n", + "for epoch in range(epochs):\n", + " total_loss = 0\n", + " momentum_val = cosine_schedule(epoch, epochs, 0.996, 1)\n", + " for batch in dataloader:\n", + " x0, x1 = batch[0]\n", + " update_momentum(model.backbone, model.backbone_momentum, m=momentum_val)\n", + " update_momentum(\n", + " model.projection_head, model.projection_head_momentum, m=momentum_val\n", + " )\n", + " x0 = x0.to(device)\n", + " x1 = x1.to(device)\n", + " p0 = model(x0)\n", + " z0 = model.forward_momentum(x0)\n", + " p1 = model(x1)\n", + " z1 = model.forward_momentum(x1)\n", + " loss = 0.5 * (criterion(p0, z1) + criterion(p1, z0))\n", + " total_loss += loss.detach()\n", + " loss.backward()\n", + " optimizer.step()\n", + " optimizer.zero_grad()\n", + " avg_loss = total_loss / len(dataloader)\n", + " print(f\"epoch: {epoch:>02}, loss: {avg_loss:.5f}\")" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": "-all" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/notebooks/pytorch/dcl.ipynb b/examples/notebooks/pytorch/dcl.ipynb new file mode 100644 index 000000000..ab9ad739a --- /dev/null +++ b/examples/notebooks/pytorch/dcl.ipynb @@ -0,0 +1,186 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "This example requires the following dependencies to be installed:\n", + "pip install lightly" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install lightly" + ] + }, + { + "cell_type": "markdown", + "id": "2", + "metadata": {}, + "source": [ + "Note: The model and training settings do not follow the reference settings\n", + "from the paper. The settings are chosen such that the example can easily be\n", + "run on a small dataset with a single GPU." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torchvision\n", + "from torch import nn" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "from lightly.loss import DCLLoss\n", + "from lightly.models.modules import SimCLRProjectionHead\n", + "from lightly.transforms.simclr_transform import SimCLRTransform" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "class DCL(nn.Module):\n", + " def __init__(self, backbone):\n", + " super().__init__()\n", + " self.backbone = backbone\n", + " self.projection_head = SimCLRProjectionHead(512, 512, 128)\n", + "\n", + " def forward(self, x):\n", + " x = self.backbone(x).flatten(start_dim=1)\n", + " z = self.projection_head(x)\n", + " return z" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "resnet = torchvision.models.resnet18()\n", + "backbone = nn.Sequential(*list(resnet.children())[:-1])\n", + "model = DCL(backbone)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "model.to(device)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "transform = SimCLRTransform(input_size=32)\n", + "dataset = torchvision.datasets.CIFAR10(\n", + " \"datasets/cifar10\", download=True, transform=transform\n", + ")\n", + "# or create a dataset from a folder containing images or videos:\n", + "# dataset = LightlyDataset(\"path/to/folder\", transform=transform)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "dataloader = torch.utils.data.DataLoader(\n", + " dataset,\n", + " batch_size=256,\n", + " shuffle=True,\n", + " drop_last=True,\n", + " num_workers=8,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [], + "source": [ + "criterion = DCLLoss()\n", + "# or use the weighted DCLW loss:\n", + "# criterion = DCLWLoss()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "11", + "metadata": {}, + "outputs": [], + "source": [ + "optimizer = torch.optim.SGD(model.parameters(), lr=0.06)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "12", + "metadata": {}, + "outputs": [], + "source": [ + "print(\"Starting Training\")\n", + "for epoch in range(10):\n", + " total_loss = 0\n", + " for batch in dataloader:\n", + " x0, x1 = batch[0]\n", + " x0 = x0.to(device)\n", + " x1 = x1.to(device)\n", + " z0 = model(x0)\n", + " z1 = model(x1)\n", + " loss = criterion(z0, z1)\n", + " total_loss += loss.detach()\n", + " loss.backward()\n", + " optimizer.step()\n", + " optimizer.zero_grad()\n", + " avg_loss = total_loss / len(dataloader)\n", + " print(f\"epoch: {epoch:>02}, loss: {avg_loss:.5f}\")" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": "-all" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/notebooks/pytorch/densecl.ipynb b/examples/notebooks/pytorch/densecl.ipynb new file mode 100644 index 000000000..cdd1b10c9 --- /dev/null +++ 
b/examples/notebooks/pytorch/densecl.ipynb @@ -0,0 +1,245 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "This example requires the following dependencies to be installed:\n", + "pip install lightly" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install lightly" + ] + }, + { + "cell_type": "markdown", + "id": "2", + "metadata": {}, + "source": [ + "Note: The model and training settings do not follow the reference settings\n", + "from the paper. The settings are chosen such that the example can easily be\n", + "run on a small dataset with a single GPU." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "import copy" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torchvision\n", + "from torch import nn" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "from lightly.loss import NTXentLoss\n", + "from lightly.models import utils\n", + "from lightly.models.modules import DenseCLProjectionHead\n", + "from lightly.transforms import DenseCLTransform\n", + "from lightly.utils.scheduler import cosine_schedule" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "class DenseCL(nn.Module):\n", + " def __init__(self, backbone):\n", + " super().__init__()\n", + " self.backbone = backbone\n", + " self.projection_head_global = DenseCLProjectionHead(512, 512, 128)\n", + " self.projection_head_local = DenseCLProjectionHead(512, 512, 128)\n", + "\n", + " self.backbone_momentum = copy.deepcopy(self.backbone)\n", + " self.projection_head_global_momentum = copy.deepcopy(\n", + " self.projection_head_global\n", + " )\n", + " self.projection_head_local_momentum = copy.deepcopy(self.projection_head_local)\n", + " self.pool = nn.AdaptiveAvgPool2d((1, 1))\n", + "\n", + " utils.deactivate_requires_grad(self.backbone_momentum)\n", + " utils.deactivate_requires_grad(self.projection_head_global_momentum)\n", + " utils.deactivate_requires_grad(self.projection_head_local_momentum)\n", + "\n", + " def forward(self, x):\n", + " query_features = self.backbone(x)\n", + " query_global = self.pool(query_features).flatten(start_dim=1)\n", + " query_global = self.projection_head_global(query_global)\n", + " query_features = query_features.flatten(start_dim=2).permute(0, 2, 1)\n", + " query_local = self.projection_head_local(query_features)\n", + " # Shapes: (B, H*W, C), (B, D), (B, H*W, D)\n", + " return query_features, query_global, query_local\n", + "\n", + " @torch.no_grad()\n", + " def forward_momentum(self, x):\n", + " key_features = self.backbone_momentum(x)\n", + " key_global = self.pool(key_features).flatten(start_dim=1)\n", + " key_global = self.projection_head_global_momentum(key_global)\n", + " key_features = key_features.flatten(start_dim=2).permute(0, 2, 1)\n", + " key_local = self.projection_head_local_momentum(key_features)\n", + " return key_features, key_global, key_local" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "resnet = torchvision.models.resnet18()\n", + "backbone = nn.Sequential(*list(resnet.children())[:-2])\n", + "model = DenseCL(backbone)" + ] + }, + { + "cell_type": "code",
"execution_count": null, + "id": "8", + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [], + "source": [ + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "model.to(device)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "transform = DenseCLTransform(input_size=32)\n", + "dataset = torchvision.datasets.CIFAR10(\n", + " \"datasets/cifar10\", download=True, transform=transform\n", + ")\n", + "# or create a dataset from a folder containing images or videos:\n", + "# dataset = LightlyDataset(\"path/to/folder\", transform=transform)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [], + "source": [ + "dataloader = torch.utils.data.DataLoader(\n", + " dataset,\n", + " batch_size=256,\n", + " shuffle=True,\n", + " drop_last=True,\n", + " num_workers=8,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "11", + "metadata": {}, + "outputs": [], + "source": [ + "criterion_global = NTXentLoss(memory_bank_size=(4096, 128))\n", + "criterion_local = NTXentLoss(memory_bank_size=(4096, 128))\n", + "optimizer = torch.optim.SGD(model.parameters(), lr=0.06)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "12", + "metadata": {}, + "outputs": [], + "source": [ + "epochs = 10" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "13", + "metadata": {}, + "outputs": [], + "source": [ + "print(\"Starting Training\")\n", + "for epoch in range(epochs):\n", + " total_loss = 0\n", + " momentum = cosine_schedule(epoch, epochs, 0.996, 1)\n", + " for batch in dataloader:\n", + " x_query, x_key = batch[0]\n", + " utils.update_momentum(model.backbone, model.backbone_momentum, m=momentum)\n", + " utils.update_momentum(\n", + " model.projection_head_global,\n", + " model.projection_head_global_momentum,\n", + " m=momentum,\n", + " )\n", + " utils.update_momentum(\n", + " model.projection_head_local,\n", + " model.projection_head_local_momentum,\n", + " m=momentum,\n", + " )\n", + " x_query = x_query.to(device)\n", + " x_key = x_key.to(device)\n", + " query_features, query_global, query_local = model(x_query)\n", + " key_features, key_global, key_local = model.forward_momentum(x_key)\n", + "\n", + " key_local = utils.select_most_similar(query_features, key_features, key_local)\n", + " query_local = query_local.flatten(end_dim=1)\n", + " key_local = key_local.flatten(end_dim=1)\n", + "\n", + " loss_global = criterion_global(query_global, key_global)\n", + " loss_local = criterion_local(query_local, key_local)\n", + " lambda_ = 0.5\n", + " loss = (1 - lambda_) * loss_global + lambda_ * loss_local\n", + " total_loss += loss.detach()\n", + " loss.backward()\n", + " optimizer.step()\n", + " optimizer.zero_grad()\n", + " avg_loss = total_loss / len(dataloader)\n", + " print(f\"epoch: {epoch:>02}, loss: {avg_loss:.5f}\")" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": "-all" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/notebooks/pytorch/dino.ipynb b/examples/notebooks/pytorch/dino.ipynb new file mode 100644 index 000000000..6ac8da549 --- /dev/null +++ b/examples/notebooks/pytorch/dino.ipynb @@ -0,0 +1,246 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "This example requires the following dependencies to be installed:\n", + "pip 
install lightly" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install lightly" + ] + }, + { + "cell_type": "markdown", + "id": "2", + "metadata": {}, + "source": [ + "Note: The model and training settings do not follow the reference settings\n", + "from the paper. The settings are chosen such that the example can easily be\n", + "run on a small dataset with a single GPU." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "import copy" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torchvision\n", + "from torch import nn" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "from lightly.loss import DINOLoss\n", + "from lightly.models.modules import DINOProjectionHead\n", + "from lightly.models.utils import deactivate_requires_grad, update_momentum\n", + "from lightly.transforms.dino_transform import DINOTransform\n", + "from lightly.utils.scheduler import cosine_schedule" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "class DINO(torch.nn.Module):\n", + " def __init__(self, backbone, input_dim):\n", + " super().__init__()\n", + " self.student_backbone = backbone\n", + " self.student_head = DINOProjectionHead(\n", + " input_dim, 512, 64, 2048, freeze_last_layer=1\n", + " )\n", + " self.teacher_backbone = copy.deepcopy(backbone)\n", + " self.teacher_head = DINOProjectionHead(input_dim, 512, 64, 2048)\n", + " deactivate_requires_grad(self.teacher_backbone)\n", + " deactivate_requires_grad(self.teacher_head)\n", + "\n", + " def forward(self, x):\n", + " y = self.student_backbone(x).flatten(start_dim=1)\n", + " z = self.student_head(y)\n", + " return z\n", + "\n", + " def forward_teacher(self, x):\n", + " y = self.teacher_backbone(x).flatten(start_dim=1)\n", + " z = self.teacher_head(y)\n", + " return z" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "resnet = torchvision.models.resnet18()\n", + "backbone = nn.Sequential(*list(resnet.children())[:-1])\n", + "input_dim = 512\n", + "# instead of a resnet you can also use a vision transformer backbone as in the\n", + "# original paper (you might have to reduce the batch size in this case):\n", + "# backbone = torch.hub.load('facebookresearch/dino:main', 'dino_vits16', pretrained=False)\n", + "# input_dim = backbone.embed_dim" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "model = DINO(backbone, input_dim)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "model.to(device)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [], + "source": [ + "transform = DINOTransform()\n", + "# we ignore object detection annotations by setting target_transform to return 0\n", + "dataset = torchvision.datasets.VOCDetection(\n", + " \"datasets/pascal_voc\",\n", + " download=True,\n", + " transform=transform,\n", + " target_transform=lambda t: 0,\n", + ")\n", + "# or create a dataset from a 
folder containing images or videos:\n", + "# dataset = LightlyDataset(\"path/to/folder\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "11", + "metadata": {}, + "outputs": [], + "source": [ + "dataloader = torch.utils.data.DataLoader(\n", + " dataset,\n", + " batch_size=64,\n", + " shuffle=True,\n", + " drop_last=True,\n", + " num_workers=8,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "12", + "metadata": {}, + "outputs": [], + "source": [ + "criterion = DINOLoss(\n", + " output_dim=2048,\n", + " warmup_teacher_temp_epochs=5,\n", + ")\n", + "# move loss to correct device because it also contains parameters\n", + "criterion = criterion.to(device)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "13", + "metadata": {}, + "outputs": [], + "source": [ + "optimizer = torch.optim.Adam(model.parameters(), lr=0.001)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "14", + "metadata": {}, + "outputs": [], + "source": [ + "epochs = 10" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "15", + "metadata": {}, + "outputs": [], + "source": [ + "print(\"Starting Training\")\n", + "for epoch in range(epochs):\n", + " total_loss = 0\n", + " momentum_val = cosine_schedule(epoch, epochs, 0.996, 1)\n", + " for batch in dataloader:\n", + " views = batch[0]\n", + " update_momentum(model.student_backbone, model.teacher_backbone, m=momentum_val)\n", + " update_momentum(model.student_head, model.teacher_head, m=momentum_val)\n", + " views = [view.to(device) for view in views]\n", + " global_views = views[:2]\n", + " teacher_out = [model.forward_teacher(view) for view in global_views]\n", + " student_out = [model.forward(view) for view in views]\n", + " loss = criterion(teacher_out, student_out, epoch=epoch)\n", + " total_loss += loss.detach()\n", + " loss.backward()\n", + " # We only cancel gradients of student head.\n", + " model.student_head.cancel_last_layer_gradients(current_epoch=epoch)\n", + " optimizer.step()\n", + " optimizer.zero_grad()\n", + "\n", + " avg_loss = total_loss / len(dataloader)\n", + " print(f\"epoch: {epoch:>02}, loss: {avg_loss:.5f}\")" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": "-all" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/notebooks/pytorch/fastsiam.ipynb b/examples/notebooks/pytorch/fastsiam.ipynb new file mode 100644 index 000000000..acab7236f --- /dev/null +++ b/examples/notebooks/pytorch/fastsiam.ipynb @@ -0,0 +1,182 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "This example requires the following dependencies to be installed:\n", + "pip install lightly" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install lightly" + ] + }, + { + "cell_type": "markdown", + "id": "2", + "metadata": {}, + "source": [ + "Note: The model and training settings do not follow the reference settings\n", + "from the paper. The settings are chosen such that the example can easily be\n", + "run on a small dataset with a single GPU." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torchvision\n", + "from torch import nn" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "from lightly.loss import NegativeCosineSimilarity\n", + "from lightly.models.modules import SimSiamPredictionHead, SimSiamProjectionHead\n", + "from lightly.transforms import FastSiamTransform" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "class FastSiam(nn.Module):\n", + " def __init__(self, backbone):\n", + " super().__init__()\n", + " self.backbone = backbone\n", + " self.projection_head = SimSiamProjectionHead(512, 512, 128)\n", + " self.prediction_head = SimSiamPredictionHead(128, 64, 128)\n", + "\n", + " def forward(self, x):\n", + " f = self.backbone(x).flatten(start_dim=1)\n", + " z = self.projection_head(f)\n", + " p = self.prediction_head(z)\n", + " z = z.detach()\n", + " return z, p" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "resnet = torchvision.models.resnet18()\n", + "backbone = nn.Sequential(*list(resnet.children())[:-1])\n", + "model = FastSiam(backbone)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "model.to(device)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "transform = FastSiamTransform(input_size=32)\n", + "dataset = torchvision.datasets.CIFAR10(\n", + " \"datasets/cifar10\", download=True, transform=transform\n", + ")\n", + "# or create a dataset from a folder containing images or videos:\n", + "# dataset = LightlyDataset(\"path/to/folder\", transform=transform)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "dataloader = torch.utils.data.DataLoader(\n", + " dataset,\n", + " batch_size=256,\n", + " shuffle=True,\n", + " drop_last=True,\n", + " num_workers=8,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [], + "source": [ + "criterion = NegativeCosineSimilarity()\n", + "optimizer = torch.optim.SGD(model.parameters(), lr=0.06)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "11", + "metadata": {}, + "outputs": [], + "source": [ + "print(\"Starting Training\")\n", + "for epoch in range(10):\n", + " total_loss = 0\n", + " for batch in dataloader:\n", + " views = batch[0]\n", + " features = [model(view.to(device)) for view in views]\n", + " zs = torch.stack([z for z, _ in features])\n", + " ps = torch.stack([p for _, p in features])\n", + "\n", + " loss = 0.0\n", + " for i in range(len(views)):\n", + " mask = torch.arange(len(views), device=device) != i\n", + " loss += criterion(ps[i], torch.mean(zs[mask], dim=0)) / len(views)\n", + "\n", + " total_loss += loss.detach()\n", + " loss.backward()\n", + " optimizer.step()\n", + " optimizer.zero_grad()\n", + " avg_loss = total_loss / len(dataloader)\n", + " print(f\"epoch: {epoch:>02}, loss: {avg_loss:.5f}\")" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + 
"notebook_metadata_filter": "-all" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/notebooks/pytorch/ijepa.ipynb b/examples/notebooks/pytorch/ijepa.ipynb new file mode 100644 index 000000000..74fadfdfd --- /dev/null +++ b/examples/notebooks/pytorch/ijepa.ipynb @@ -0,0 +1,236 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "This example requires the following dependencies to be installed:\n", + "pip install lightly[timm]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install lightly[timm]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2", + "metadata": {}, + "outputs": [], + "source": [ + "import copy" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torchvision\n", + "from torch import nn\n", + "from torch.nn import functional as F\n", + "from tqdm import tqdm" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "from lightly.data.collate import IJEPAMaskCollator\n", + "from lightly.models import utils\n", + "from lightly.models.modules.ijepa import IJEPABackbone, IJEPAPredictor\n", + "from lightly.transforms.ijepa_transform import IJEPATransform" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "class IJEPA(nn.Module):\n", + " def __init__(self, vit_encoder, vit_predictor, momentum_scheduler):\n", + " super().__init__()\n", + " self.encoder = IJEPABackbone.from_vit(vit_encoder)\n", + " self.predictor = IJEPAPredictor.from_vit_encoder(\n", + " vit_predictor.encoder,\n", + " (vit_predictor.image_size // vit_predictor.patch_size) ** 2,\n", + " )\n", + " self.target_encoder = copy.deepcopy(self.encoder)\n", + " self.momentum_scheduler = momentum_scheduler\n", + "\n", + " def forward_target(self, imgs, masks_enc, masks_pred):\n", + " with torch.no_grad():\n", + " h = self.target_encoder(imgs)\n", + " h = F.layer_norm(h, (h.size(-1),)) # normalize over feature-dim\n", + " B = len(h)\n", + " # -- create targets (masked regions of h)\n", + " h = utils.apply_masks(h, masks_pred)\n", + " h = utils.repeat_interleave_batch(h, B, repeat=len(masks_enc))\n", + " return h\n", + "\n", + " def forward_context(self, imgs, masks_enc, masks_pred):\n", + " z = self.encoder(imgs, masks_enc)\n", + " z = self.predictor(z, masks_enc, masks_pred)\n", + " return z\n", + "\n", + " def forward(self, imgs, masks_enc, masks_pred):\n", + " z = self.forward_context(imgs, masks_enc, masks_pred)\n", + " h = self.forward_target(imgs, masks_enc, masks_pred)\n", + " return z, h\n", + "\n", + " def update_target_encoder(\n", + " self,\n", + " ):\n", + " with torch.no_grad():\n", + " m = next(self.momentum_scheduler)\n", + " for param_q, param_k in zip(\n", + " self.encoder.parameters(), self.target_encoder.parameters()\n", + " ):\n", + " param_k.data.mul_(m).add_((1.0 - m) * param_q.detach().data)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "collator = IJEPAMaskCollator(\n", + " input_size=(224, 224),\n", + " patch_size=32,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "transform = IJEPATransform()" + ] + }, + { + 
"cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "# we ignore object detection annotations by setting target_transform to return 0\n", + "# or create a dataset from a folder containing images or videos:\n", + "# dataset = LightlyDataset(\"path/to/folder\")\n", + "dataset = torchvision.datasets.VOCDetection(\n", + " \"datasets/pascal_voc\",\n", + " download=True,\n", + " transform=transform,\n", + " target_transform=lambda t: 0,\n", + ")\n", + "data_loader = torch.utils.data.DataLoader(\n", + " dataset, collate_fn=collator, batch_size=10, persistent_workers=False\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "ema = (0.996, 1.0)\n", + "ipe_scale = 1.0\n", + "ipe = len(data_loader)\n", + "num_epochs = 10\n", + "momentum_scheduler = (\n", + " ema[0] + i * (ema[1] - ema[0]) / (ipe * num_epochs * ipe_scale)\n", + " for i in range(int(ipe * num_epochs * ipe_scale) + 1)\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [], + "source": [ + "vit_for_predictor = torchvision.models.vit_b_32(pretrained=False)\n", + "vit_for_embedder = torchvision.models.vit_b_32(pretrained=False)\n", + "model = IJEPA(vit_for_embedder, vit_for_predictor, momentum_scheduler)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "11", + "metadata": {}, + "outputs": [], + "source": [ + "criterion = nn.SmoothL1Loss()\n", + "optimizer = torch.optim.AdamW(model.parameters(), lr=1.5e-4)\n", + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "model.to(device)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "12", + "metadata": {}, + "outputs": [], + "source": [ + "print(\"Starting Training\")\n", + "for epoch in range(num_epochs):\n", + " total_loss = 0\n", + " for udata, masks_enc, masks_pred in tqdm(data_loader):\n", + "\n", + " def load_imgs():\n", + " # -- unsupervised imgs\n", + " imgs = udata[0].to(device, non_blocking=True)\n", + " masks_1 = [u.to(device, non_blocking=True) for u in masks_enc]\n", + " masks_2 = [u.to(device, non_blocking=True) for u in masks_pred]\n", + " return (imgs, masks_1, masks_2)\n", + "\n", + " imgs, masks_enc, masks_pred = load_imgs()\n", + " z, h = model(imgs, masks_enc, masks_pred)\n", + " loss = criterion(z, h)\n", + " total_loss += loss.detach()\n", + " loss.backward()\n", + " optimizer.step()\n", + " optimizer.zero_grad()\n", + " model.update_target_encoder()\n", + "\n", + " avg_loss = total_loss / len(data_loader)\n", + " print(f\"epoch: {epoch:>02}, loss: {avg_loss:.5f}\")" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": "-all" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/notebooks/pytorch/mae.ipynb b/examples/notebooks/pytorch/mae.ipynb new file mode 100644 index 000000000..ee50e19b8 --- /dev/null +++ b/examples/notebooks/pytorch/mae.ipynb @@ -0,0 +1,219 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "This example requires the following dependencies to be installed:\n", + "pip install lightly[timm]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install lightly[timm]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2", + "metadata": {}, + 
"outputs": [], + "source": [ + "# Note: The model and training settings do not follow the reference settings\n", + "# from the paper. The settings are chosen such that the example can easily be\n", + "# run on a small dataset with a single GPU.\n", + "import torch\n", + "import torchvision\n", + "from timm.models.vision_transformer import vit_base_patch32_224\n", + "from torch import nn" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "from lightly.models import utils\n", + "from lightly.models.modules import MAEDecoderTIMM, MaskedVisionTransformerTIMM\n", + "from lightly.transforms import MAETransform" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "class MAE(nn.Module):\n", + " def __init__(self, vit):\n", + " super().__init__()\n", + "\n", + " decoder_dim = 512\n", + " self.mask_ratio = 0.75\n", + " self.patch_size = vit.patch_embed.patch_size[0]\n", + "\n", + " self.backbone = MaskedVisionTransformerTIMM(vit=vit)\n", + " self.sequence_length = self.backbone.sequence_length\n", + " self.decoder = MAEDecoderTIMM(\n", + " num_patches=vit.patch_embed.num_patches,\n", + " patch_size=self.patch_size,\n", + " embed_dim=vit.embed_dim,\n", + " decoder_embed_dim=decoder_dim,\n", + " decoder_depth=1,\n", + " decoder_num_heads=16,\n", + " mlp_ratio=4.0,\n", + " proj_drop_rate=0.0,\n", + " attn_drop_rate=0.0,\n", + " )\n", + "\n", + " def forward_encoder(self, images, idx_keep=None):\n", + " return self.backbone.encode(images=images, idx_keep=idx_keep)\n", + "\n", + " def forward_decoder(self, x_encoded, idx_keep, idx_mask):\n", + " # build decoder input\n", + " batch_size = x_encoded.shape[0]\n", + " x_decode = self.decoder.embed(x_encoded)\n", + " x_masked = utils.repeat_token(\n", + " self.decoder.mask_token, (batch_size, self.sequence_length)\n", + " )\n", + " x_masked = utils.set_at_index(x_masked, idx_keep, x_decode.type_as(x_masked))\n", + "\n", + " # decoder forward pass\n", + " x_decoded = self.decoder.decode(x_masked)\n", + "\n", + " # predict pixel values for masked tokens\n", + " x_pred = utils.get_at_index(x_decoded, idx_mask)\n", + " x_pred = self.decoder.predict(x_pred)\n", + " return x_pred\n", + "\n", + " def forward(self, images):\n", + " batch_size = images.shape[0]\n", + " idx_keep, idx_mask = utils.random_token_mask(\n", + " size=(batch_size, self.sequence_length),\n", + " mask_ratio=self.mask_ratio,\n", + " device=images.device,\n", + " )\n", + " x_encoded = self.forward_encoder(images=images, idx_keep=idx_keep)\n", + " x_pred = self.forward_decoder(\n", + " x_encoded=x_encoded, idx_keep=idx_keep, idx_mask=idx_mask\n", + " )\n", + "\n", + " # get image patches for masked tokens\n", + " patches = utils.patchify(images, self.patch_size)\n", + " # must adjust idx_mask for missing class token\n", + " target = utils.get_at_index(patches, idx_mask - 1)\n", + " return x_pred, target" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "vit = vit_base_patch32_224()\n", + "model = MAE(vit)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "model.to(device)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "transform = MAETransform()\n", 
+ "# we ignore object detection annotations by setting target_transform to return 0\n", + "dataset = torchvision.datasets.VOCDetection(\n", + " \"datasets/pascal_voc\",\n", + " download=True,\n", + " transform=transform,\n", + " target_transform=lambda t: 0,\n", + ")\n", + "# or create a dataset from a folder containing images or videos:\n", + "# dataset = LightlyDataset(\"path/to/folder\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "dataloader = torch.utils.data.DataLoader(\n", + " dataset,\n", + " batch_size=256,\n", + " shuffle=True,\n", + " drop_last=True,\n", + " num_workers=8,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "criterion = nn.MSELoss()\n", + "optimizer = torch.optim.AdamW(model.parameters(), lr=1.5e-4)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [], + "source": [ + "print(\"Starting Training\")\n", + "for epoch in range(10):\n", + " total_loss = 0\n", + " for batch in dataloader:\n", + " views = batch[0]\n", + " images = views[0].to(device) # views contains only a single view\n", + " predictions, targets = model(images)\n", + " loss = criterion(predictions, targets)\n", + " total_loss += loss.detach()\n", + " loss.backward()\n", + " optimizer.step()\n", + " optimizer.zero_grad()\n", + " avg_loss = total_loss / len(dataloader)\n", + " print(f\"epoch: {epoch:>02}, loss: {avg_loss:.5f}\")" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": "-all" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/notebooks/pytorch/mmcr.ipynb b/examples/notebooks/pytorch/mmcr.ipynb new file mode 100644 index 000000000..1798008ad --- /dev/null +++ b/examples/notebooks/pytorch/mmcr.ipynb @@ -0,0 +1,217 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "This example requires the following dependencies to be installed:\n", + "pip install lightly" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install lightly" + ] + }, + { + "cell_type": "markdown", + "id": "2", + "metadata": {}, + "source": [ + "Note: The model and training settings do not follow the reference settings\n", + "from the paper. The settings are chosen such that the example can easily be\n", + "run on a small dataset with a single GPU." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "import copy" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torchvision\n", + "from torch import nn" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "from lightly.loss import MMCRLoss\n", + "from lightly.models.modules import MMCRProjectionHead\n", + "from lightly.models.utils import deactivate_requires_grad, update_momentum\n", + "from lightly.transforms.mmcr_transform import MMCRTransform\n", + "from lightly.utils.scheduler import cosine_schedule" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "class MMCR(nn.Module):\n", + " def __init__(self, backbone):\n", + " super().__init__()\n", + "\n", + " self.backbone = backbone\n", + " self.projection_head = MMCRProjectionHead(512, 512, 128)\n", + "\n", + " self.backbone_momentum = copy.deepcopy(self.backbone)\n", + " self.projection_head_momentum = copy.deepcopy(self.projection_head)\n", + "\n", + " deactivate_requires_grad(self.backbone_momentum)\n", + " deactivate_requires_grad(self.projection_head_momentum)\n", + "\n", + " def forward(self, x):\n", + " y = self.backbone(x).flatten(start_dim=1)\n", + " z = self.projection_head(y)\n", + " return z\n", + "\n", + " def forward_momentum(self, x):\n", + " y = self.backbone_momentum(x).flatten(start_dim=1)\n", + " z = self.projection_head_momentum(y)\n", + " z = z.detach()\n", + " return z" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "resnet = torchvision.models.resnet18()\n", + "backbone = nn.Sequential(*list(resnet.children())[:-1])\n", + "model = MMCR(backbone)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "model.to(device)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "transform = MMCRTransform(k=8, input_size=32, gaussian_blur=0.0)\n", + "dataset = torchvision.datasets.CIFAR10(\n", + " \"datasets/cifar10\", download=True, transform=transform\n", + ")\n", + "# or create a dataset from a folder containing images or videos:\n", + "# dataset = LightlyDataset(\"path/to/folder\", transform=transform)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [], + "source": [ + "dataloader = torch.utils.data.DataLoader(\n", + " dataset,\n", + " batch_size=256,\n", + " shuffle=True,\n", + " drop_last=True,\n", + " num_workers=8,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "11", + "metadata": {}, + "outputs": [], + "source": [ + "criterion = MMCRLoss()\n", + "optimizer = torch.optim.SGD(model.parameters(), lr=0.06)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "12", + "metadata": {}, + "outputs": [], + "source": [ + "epochs = 10" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "13", + "metadata": {}, + "outputs": [], + "source": [ + "print(\"Starting Training\")\n", + "for epoch in range(epochs):\n", + " total_loss = 0\n", + " momentum_val = cosine_schedule(epoch, epochs, 
0.996, 1)\n", + " for batch in dataloader:\n", + " update_momentum(model.backbone, model.backbone_momentum, m=momentum_val)\n", + " update_momentum(\n", + " model.projection_head, model.projection_head_momentum, m=momentum_val\n", + " )\n", + " z_o = [model(x.to(device)) for x in batch[0]]\n", + " z_m = [model.forward_momentum(x.to(device)) for x in batch[0]]\n", + "\n", + " # Switch dimensions to (batch_size, k, embedding_size)\n", + " z_o = torch.stack(z_o, dim=1)\n", + " z_m = torch.stack(z_m, dim=1)\n", + "\n", + " loss = criterion(z_o, z_m)\n", + " total_loss += loss.detach()\n", + " loss.backward()\n", + " optimizer.step()\n", + " optimizer.zero_grad()\n", + " avg_loss = total_loss / len(dataloader)\n", + " print(f\"epoch: {epoch:>02}, loss: {avg_loss:.5f}\")" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": "-all" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/notebooks/pytorch/moco.ipynb b/examples/notebooks/pytorch/moco.ipynb new file mode 100644 index 000000000..d73691d59 --- /dev/null +++ b/examples/notebooks/pytorch/moco.ipynb @@ -0,0 +1,216 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "This example requires the following dependencies to be installed:\n", + "pip install lightly" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install lightly" + ] + }, + { + "cell_type": "markdown", + "id": "2", + "metadata": {}, + "source": [ + "Note: The model and training settings do not follow the reference settings\n", + "from the paper. The settings are chosen such that the example can easily be\n", + "run on a small dataset with a single GPU." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "import copy" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torchvision\n", + "from torch import nn" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "from lightly.loss import NTXentLoss\n", + "from lightly.models.modules import MoCoProjectionHead\n", + "from lightly.models.utils import deactivate_requires_grad, update_momentum\n", + "from lightly.transforms.moco_transform import MoCoV2Transform\n", + "from lightly.utils.scheduler import cosine_schedule" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "class MoCo(nn.Module):\n", + " def __init__(self, backbone):\n", + " super().__init__()\n", + "\n", + " self.backbone = backbone\n", + " self.projection_head = MoCoProjectionHead(512, 512, 128)\n", + "\n", + " self.backbone_momentum = copy.deepcopy(self.backbone)\n", + " self.projection_head_momentum = copy.deepcopy(self.projection_head)\n", + "\n", + " deactivate_requires_grad(self.backbone_momentum)\n", + " deactivate_requires_grad(self.projection_head_momentum)\n", + "\n", + " def forward(self, x):\n", + " query = self.backbone(x).flatten(start_dim=1)\n", + " query = self.projection_head(query)\n", + " return query\n", + "\n", + " def forward_momentum(self, x):\n", + " key = self.backbone_momentum(x).flatten(start_dim=1)\n", + " key = self.projection_head_momentum(key).detach()\n", + " return key" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "resnet = torchvision.models.resnet18()\n", + "backbone = nn.Sequential(*list(resnet.children())[:-1])\n", + "model = MoCo(backbone)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [], + "source": [ + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "model.to(device)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "transform = MoCoV2Transform(input_size=32)\n", + "dataset = torchvision.datasets.CIFAR10(\n", + " \"datasets/cifar10\", download=True, transform=transform\n", + ")\n", + "# or create a dataset from a folder containing images or videos:\n", + "# dataset = LightlyDataset(\"path/to/folder\", transform=transform)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [], + "source": [ + "dataloader = torch.utils.data.DataLoader(\n", + " dataset,\n", + " batch_size=256,\n", + " shuffle=True,\n", + " drop_last=True,\n", + " num_workers=8,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "11", + "metadata": {}, + "outputs": [], + "source": [ + "criterion = NTXentLoss(memory_bank_size=(4096, 128))\n", + "optimizer = torch.optim.SGD(model.parameters(), lr=0.06)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "12", + "metadata": {}, + "outputs": [], + "source": [ + "epochs = 10" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "13", + "metadata": {}, + "outputs": [], + "source": [ + "print(\"Starting Training\")\n", + "for epoch in range(epochs):\n", + " total_loss = 
0\n", + " momentum_val = cosine_schedule(epoch, epochs, 0.996, 1)\n", + " for batch in dataloader:\n", + " x_query, x_key = batch[0]\n", + " update_momentum(model.backbone, model.backbone_momentum, m=momentum_val)\n", + " update_momentum(\n", + " model.projection_head, model.projection_head_momentum, m=momentum_val\n", + " )\n", + " x_query = x_query.to(device)\n", + " x_key = x_key.to(device)\n", + " query = model(x_query)\n", + " key = model.forward_momentum(x_key)\n", + " loss = criterion(query, key)\n", + " total_loss += loss.detach()\n", + " loss.backward()\n", + " optimizer.step()\n", + " optimizer.zero_grad()\n", + " avg_loss = total_loss / len(dataloader)\n", + " print(f\"epoch: {epoch:>02}, loss: {avg_loss:.5f}\")" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": "-all" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/notebooks/pytorch/msn.ipynb b/examples/notebooks/pytorch/msn.ipynb new file mode 100644 index 000000000..ed273643d --- /dev/null +++ b/examples/notebooks/pytorch/msn.ipynb @@ -0,0 +1,240 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "This example requires the following dependencies to be installed:\n", + "pip install lightly" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install lightly" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2", + "metadata": {}, + "outputs": [], + "source": [ + "# Note: The model and training settings do not follow the reference settings\n", + "# from the paper. The settings are chosen such that the example can easily be\n", + "# run on a small dataset with a single GPU.\n", + "import copy" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torchvision\n", + "from torch import nn" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "from lightly.loss import MSNLoss\n", + "from lightly.models import utils\n", + "from lightly.models.modules import MaskedVisionTransformerTorchvision\n", + "from lightly.models.modules.heads import MSNProjectionHead\n", + "from lightly.transforms.msn_transform import MSNTransform" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "class MSN(nn.Module):\n", + " def __init__(self, vit):\n", + " super().__init__()\n", + "\n", + " self.mask_ratio = 0.15\n", + " self.backbone = MaskedVisionTransformerTorchvision(vit=vit)\n", + " self.projection_head = MSNProjectionHead(input_dim=384)\n", + "\n", + " self.anchor_backbone = copy.deepcopy(self.backbone)\n", + " self.anchor_projection_head = copy.deepcopy(self.projection_head)\n", + "\n", + " utils.deactivate_requires_grad(self.backbone)\n", + " utils.deactivate_requires_grad(self.projection_head)\n", + "\n", + " self.prototypes = nn.Linear(256, 1024, bias=False).weight\n", + "\n", + " def forward(self, images):\n", + " out = self.backbone(images=images)\n", + " return self.projection_head(out)\n", + "\n", + " def forward_masked(self, images):\n", + " batch_size, _, _, width = images.shape\n", + " seq_length = (width // self.anchor_backbone.vit.patch_size) ** 2\n", + " idx_keep, _ = utils.random_token_mask(\n", + " size=(batch_size, seq_length),\n", 
+ " mask_ratio=self.mask_ratio,\n", + " device=images.device,\n", + " )\n", + " out = self.anchor_backbone(images=images, idx_keep=idx_keep)\n", + " return self.anchor_projection_head(out)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "# ViT small configuration (ViT-S/16)\n", + "vit = torchvision.models.VisionTransformer(\n", + " image_size=224,\n", + " patch_size=16,\n", + " num_layers=12,\n", + " num_heads=6,\n", + " hidden_dim=384,\n", + " mlp_dim=384 * 4,\n", + ")\n", + "model = MSN(vit)\n", + "# or use a torchvision ViT backbone:\n", + "# vit = torchvision.models.vit_b_32(pretrained=False)\n", + "# model = MSN(vit)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "model.to(device)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "transform = MSNTransform()\n", + "# we ignore object detection annotations by setting target_transform to return 0\n", + "dataset = torchvision.datasets.VOCDetection(\n", + " \"datasets/pascal_voc\",\n", + " download=True,\n", + " transform=transform,\n", + " target_transform=lambda t: 0,\n", + ")\n", + "# or create a dataset from a folder containing images or videos:\n", + "# dataset = LightlyDataset(\"path/to/folder\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "dataloader = torch.utils.data.DataLoader(\n", + " dataset,\n", + " batch_size=64,\n", + " shuffle=True,\n", + " drop_last=True,\n", + " num_workers=8,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [], + "source": [ + "criterion = MSNLoss()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "11", + "metadata": {}, + "outputs": [], + "source": [ + "params = [\n", + " *list(model.anchor_backbone.parameters()),\n", + " *list(model.anchor_projection_head.parameters()),\n", + " model.prototypes,\n", + "]\n", + "optimizer = torch.optim.AdamW(params, lr=1.5e-4)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "12", + "metadata": {}, + "outputs": [], + "source": [ + "print(\"Starting Training\")\n", + "for epoch in range(10):\n", + " total_loss = 0\n", + " for batch in dataloader:\n", + " views = batch[0]\n", + " utils.update_momentum(model.anchor_backbone, model.backbone, 0.996)\n", + " utils.update_momentum(\n", + " model.anchor_projection_head, model.projection_head, 0.996\n", + " )\n", + "\n", + " views = [view.to(device, non_blocking=True) for view in views]\n", + " targets = views[0]\n", + " anchors = views[1]\n", + " anchors_focal = torch.concat(views[2:], dim=0)\n", + "\n", + " targets_out = model.backbone(targets)\n", + " targets_out = model.projection_head(targets_out)\n", + " anchors_out = model.forward_masked(anchors)\n", + " anchors_focal_out = model.forward_masked(anchors_focal)\n", + " anchors_out = torch.cat([anchors_out, anchors_focal_out], dim=0)\n", + "\n", + " loss = criterion(anchors_out, targets_out, model.prototypes.data)\n", + " total_loss += loss.detach()\n", + " loss.backward()\n", + " optimizer.step()\n", + " optimizer.zero_grad()\n", + " avg_loss = total_loss / len(dataloader)\n", + " print(f\"epoch: {epoch:>02}, loss: {avg_loss:.5f}\")" + ] + } + ], + "metadata": { + "jupytext": { 
+ "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": "-all" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/notebooks/pytorch/nnclr.ipynb b/examples/notebooks/pytorch/nnclr.ipynb new file mode 100644 index 000000000..f916b2b2d --- /dev/null +++ b/examples/notebooks/pytorch/nnclr.ipynb @@ -0,0 +1,196 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "This example requires the following dependencies to be installed:\n", + "pip install lightly" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install lightly" + ] + }, + { + "cell_type": "markdown", + "id": "2", + "metadata": {}, + "source": [ + "Note: The model and training settings do not follow the reference settings\n", + "from the paper. The settings are chosen such that the example can easily be\n", + "run on a small dataset with a single GPU." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torchvision\n", + "from torch import nn" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "from lightly.loss import NTXentLoss\n", + "from lightly.models.modules import (\n", + " NNCLRPredictionHead,\n", + " NNCLRProjectionHead,\n", + " NNMemoryBankModule,\n", + ")\n", + "from lightly.transforms.simclr_transform import SimCLRTransform" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "class NNCLR(nn.Module):\n", + " def __init__(self, backbone):\n", + " super().__init__()\n", + "\n", + " self.backbone = backbone\n", + " self.projection_head = NNCLRProjectionHead(512, 512, 128)\n", + " self.prediction_head = NNCLRPredictionHead(128, 512, 128)\n", + "\n", + " def forward(self, x):\n", + " y = self.backbone(x).flatten(start_dim=1)\n", + " z = self.projection_head(y)\n", + " p = self.prediction_head(z)\n", + " z = z.detach()\n", + " return z, p" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "resnet = torchvision.models.resnet18()\n", + "backbone = nn.Sequential(*list(resnet.children())[:-1])\n", + "model = NNCLR(backbone)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "model.to(device)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "memory_bank = NNMemoryBankModule(size=(4096, 128))\n", + "memory_bank.to(device)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "transform = SimCLRTransform(input_size=32)\n", + "dataset = torchvision.datasets.CIFAR10(\n", + " \"datasets/cifar10\", download=True, transform=transform\n", + ")\n", + "# or create a dataset from a folder containing images or videos:\n", + "# dataset = LightlyDataset(\"path/to/folder\", transform=transform)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [], + "source": [ + "dataloader = torch.utils.data.DataLoader(\n", + " dataset,\n", + " batch_size=256,\n", + " shuffle=True,\n", + " 
drop_last=True,\n", + " num_workers=8,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "11", + "metadata": {}, + "outputs": [], + "source": [ + "criterion = NTXentLoss()\n", + "optimizer = torch.optim.SGD(model.parameters(), lr=0.06)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "12", + "metadata": {}, + "outputs": [], + "source": [ + "print(\"Starting Training\")\n", + "for epoch in range(10):\n", + " total_loss = 0\n", + " for batch in dataloader:\n", + " x0, x1 = batch[0]\n", + " x0 = x0.to(device)\n", + " x1 = x1.to(device)\n", + " z0, p0 = model(x0)\n", + " z1, p1 = model(x1)\n", + " z0 = memory_bank(z0, update=False)\n", + " z1 = memory_bank(z1, update=True)\n", + " loss = 0.5 * (criterion(z0, p1) + criterion(z1, p0))\n", + " total_loss += loss.detach()\n", + " loss.backward()\n", + " optimizer.step()\n", + " optimizer.zero_grad()\n", + " avg_loss = total_loss / len(dataloader)\n", + " print(f\"epoch: {epoch:>02}, loss: {avg_loss:.5f}\")" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": "-all" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/notebooks/pytorch/pmsn.ipynb b/examples/notebooks/pytorch/pmsn.ipynb new file mode 100644 index 000000000..7555f11bf --- /dev/null +++ b/examples/notebooks/pytorch/pmsn.ipynb @@ -0,0 +1,240 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "This example requires the following dependencies to be installed:\n", + "pip install lightly" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install lightly" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2", + "metadata": {}, + "outputs": [], + "source": [ + "# Note: The model and training settings do not follow the reference settings\n", + "# from the paper. 
The settings are chosen such that the example can easily be\n", + "# run on a small dataset with a single GPU.\n", + "import copy" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torchvision\n", + "from torch import nn" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "from lightly.loss import PMSNLoss\n", + "from lightly.models import utils\n", + "from lightly.models.modules import MaskedVisionTransformerTorchvision\n", + "from lightly.models.modules.heads import MSNProjectionHead\n", + "from lightly.transforms import MSNTransform" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "class PMSN(nn.Module):\n", + " def __init__(self, vit):\n", + " super().__init__()\n", + "\n", + " self.mask_ratio = 0.15\n", + " self.backbone = MaskedVisionTransformerTorchvision(vit=vit)\n", + " self.projection_head = MSNProjectionHead(384)\n", + "\n", + " self.anchor_backbone = copy.deepcopy(self.backbone)\n", + " self.anchor_projection_head = copy.deepcopy(self.projection_head)\n", + "\n", + " utils.deactivate_requires_grad(self.backbone)\n", + " utils.deactivate_requires_grad(self.projection_head)\n", + "\n", + " self.prototypes = nn.Linear(256, 1024, bias=False).weight\n", + "\n", + " def forward(self, images):\n", + " out = self.backbone(images=images)\n", + " return self.projection_head(out)\n", + "\n", + " def forward_masked(self, images):\n", + " batch_size, _, _, width = images.shape\n", + " seq_length = (width // self.anchor_backbone.vit.patch_size) ** 2\n", + " idx_keep, _ = utils.random_token_mask(\n", + " size=(batch_size, seq_length),\n", + " mask_ratio=self.mask_ratio,\n", + " device=images.device,\n", + " )\n", + " out = self.anchor_backbone(images=images, idx_keep=idx_keep)\n", + " return self.anchor_projection_head(out)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "# ViT small configuration (ViT-S/16)\n", + "vit = torchvision.models.VisionTransformer(\n", + " image_size=224,\n", + " patch_size=16,\n", + " num_layers=12,\n", + " num_heads=6,\n", + " hidden_dim=384,\n", + " mlp_dim=384 * 4,\n", + ")\n", + "model = PMSN(vit)\n", + "# # or use a torchvision ViT backbone:\n", + "# vit = torchvision.models.vit_b_32(pretrained=False)\n", + "# model = PMSN(vit)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "model.to(device)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "transform = MSNTransform()\n", + "# we ignore object detection annotations by setting target_transform to return 0\n", + "dataset = torchvision.datasets.VOCDetection(\n", + " \"datasets/pascal_voc\",\n", + " download=True,\n", + " transform=transform,\n", + " target_transform=lambda t: 0,\n", + ")\n", + "# or create a dataset from a folder containing images or videos:\n", + "# dataset = LightlyDataset(\"path/to/folder\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "dataloader = torch.utils.data.DataLoader(\n", + " dataset,\n", + " batch_size=64,\n", + " shuffle=True,\n", + " 
drop_last=True,\n", + " num_workers=8,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [], + "source": [ + "criterion = PMSNLoss()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "11", + "metadata": {}, + "outputs": [], + "source": [ + "params = [\n", + " *list(model.anchor_backbone.parameters()),\n", + " *list(model.anchor_projection_head.parameters()),\n", + " model.prototypes,\n", + "]\n", + "optimizer = torch.optim.AdamW(params, lr=1.5e-4)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "12", + "metadata": {}, + "outputs": [], + "source": [ + "print(\"Starting Training\")\n", + "for epoch in range(10):\n", + " total_loss = 0\n", + " for batch in dataloader:\n", + " views = batch[0]\n", + " utils.update_momentum(model.anchor_backbone, model.backbone, 0.996)\n", + " utils.update_momentum(\n", + " model.anchor_projection_head, model.projection_head, 0.996\n", + " )\n", + "\n", + " views = [view.to(device, non_blocking=True) for view in views]\n", + " targets = views[0]\n", + " anchors = views[1]\n", + " anchors_focal = torch.concat(views[2:], dim=0)\n", + "\n", + " targets_out = model.backbone(images=targets)\n", + " targets_out = model.projection_head(targets_out)\n", + " anchors_out = model.forward_masked(anchors)\n", + " anchors_focal_out = model.forward_masked(anchors_focal)\n", + " anchors_out = torch.cat([anchors_out, anchors_focal_out], dim=0)\n", + "\n", + " loss = criterion(anchors_out, targets_out, model.prototypes.data)\n", + " total_loss += loss.detach()\n", + " loss.backward()\n", + " optimizer.step()\n", + " optimizer.zero_grad()\n", + " avg_loss = total_loss / len(dataloader)\n", + " print(f\"epoch: {epoch:>02}, loss: {avg_loss:.5f}\")" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": "-all" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/notebooks/pytorch/simclr.ipynb b/examples/notebooks/pytorch/simclr.ipynb new file mode 100644 index 000000000..9db02c185 --- /dev/null +++ b/examples/notebooks/pytorch/simclr.ipynb @@ -0,0 +1,175 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "This example requires the following dependencies to be installed:\n", + "pip install lightly" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install lightly" + ] + }, + { + "cell_type": "markdown", + "id": "2", + "metadata": {}, + "source": [ + "Note: The model and training settings do not follow the reference settings\n", + "from the paper. The settings are chosen such that the example can easily be\n", + "run on a small dataset with a single GPU." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torchvision\n", + "from torch import nn" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "from lightly.loss import NTXentLoss\n", + "from lightly.models.modules import SimCLRProjectionHead\n", + "from lightly.transforms.simclr_transform import SimCLRTransform" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "class SimCLR(nn.Module):\n", + " def __init__(self, backbone):\n", + " super().__init__()\n", + " self.backbone = backbone\n", + " self.projection_head = SimCLRProjectionHead(512, 512, 128)\n", + "\n", + " def forward(self, x):\n", + " x = self.backbone(x).flatten(start_dim=1)\n", + " z = self.projection_head(x)\n", + " return z" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "resnet = torchvision.models.resnet18()\n", + "backbone = nn.Sequential(*list(resnet.children())[:-1])\n", + "model = SimCLR(backbone)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "model.to(device)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "transform = SimCLRTransform(input_size=32, gaussian_blur=0.0)\n", + "dataset = torchvision.datasets.CIFAR10(\n", + " \"datasets/cifar10\", download=True, transform=transform\n", + ")\n", + "# or create a dataset from a folder containing images or videos:\n", + "# dataset = LightlyDataset(\"path/to/folder\", transform=transform)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "dataloader = torch.utils.data.DataLoader(\n", + " dataset,\n", + " batch_size=256,\n", + " shuffle=True,\n", + " drop_last=True,\n", + " num_workers=8,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [], + "source": [ + "criterion = NTXentLoss()\n", + "optimizer = torch.optim.SGD(model.parameters(), lr=0.06)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "11", + "metadata": {}, + "outputs": [], + "source": [ + "print(\"Starting Training\")\n", + "for epoch in range(10):\n", + " total_loss = 0\n", + " for batch in dataloader:\n", + " x0, x1 = batch[0]\n", + " x0 = x0.to(device)\n", + " x1 = x1.to(device)\n", + " z0 = model(x0)\n", + " z1 = model(x1)\n", + " loss = criterion(z0, z1)\n", + " total_loss += loss.detach()\n", + " loss.backward()\n", + " optimizer.step()\n", + " optimizer.zero_grad()\n", + " avg_loss = total_loss / len(dataloader)\n", + " print(f\"epoch: {epoch:>02}, loss: {avg_loss:.5f}\")" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": "-all" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/notebooks/pytorch/simmim.ipynb b/examples/notebooks/pytorch/simmim.ipynb new file mode 100644 index 000000000..fac4f8478 --- /dev/null +++ b/examples/notebooks/pytorch/simmim.ipynb @@ -0,0 +1,203 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "This example 
requires the following dependencies to be installed:\n", + "pip install lightly" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install lightly" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torchvision\n", + "from torch import nn" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "from lightly.models import utils\n", + "from lightly.models.modules.masked_vision_transformer_torchvision import (\n", + " MaskedVisionTransformerTorchvision,\n", + ")\n", + "from lightly.transforms.mae_transform import MAETransform # Same transform as MAE" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "class SimMIM(nn.Module):\n", + " def __init__(self, vit):\n", + " super().__init__()\n", + "\n", + " decoder_dim = vit.hidden_dim\n", + " self.mask_ratio = 0.75\n", + " self.patch_size = vit.patch_size\n", + " self.sequence_length = vit.seq_length\n", + "\n", + " self.backbone = MaskedVisionTransformerTorchvision(vit=vit)\n", + "\n", + " # the decoder is a simple linear layer\n", + " self.decoder = nn.Linear(decoder_dim, vit.patch_size**2 * 3)\n", + "\n", + " def forward_encoder(self, images, batch_size, idx_mask):\n", + " # pass all the tokens to the encoder, both masked and non masked ones\n", + " return self.backbone.encode(images=images, idx_mask=idx_mask)\n", + "\n", + " def forward_decoder(self, x_encoded):\n", + " return self.decoder(x_encoded)\n", + "\n", + " def forward(self, images):\n", + " batch_size = images.shape[0]\n", + " idx_keep, idx_mask = utils.random_token_mask(\n", + " size=(batch_size, self.sequence_length),\n", + " mask_ratio=self.mask_ratio,\n", + " device=images.device,\n", + " )\n", + "\n", + " # Encoding...\n", + " x_encoded = self.forward_encoder(images, batch_size, idx_mask)\n", + " x_encoded_masked = utils.get_at_index(x_encoded, idx_mask)\n", + "\n", + " # Decoding...\n", + " x_out = self.forward_decoder(x_encoded_masked)\n", + "\n", + " # get image patches for masked tokens\n", + " patches = utils.patchify(images, self.patch_size)\n", + "\n", + " # must adjust idx_mask for missing class token\n", + " target = utils.get_at_index(patches, idx_mask - 1)\n", + "\n", + " return x_out, target" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "vit = torchvision.models.vit_b_32(pretrained=False)\n", + "model = SimMIM(vit)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "model.to(device)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "transform = MAETransform()\n", + "# we ignore object detection annotations by setting target_transform to return 0\n", + "dataset = torchvision.datasets.VOCDetection(\n", + " \"datasets/pascal_voc\",\n", + " download=True,\n", + " transform=transform,\n", + " target_transform=lambda t: 0,\n", + ")\n", + "# or create a dataset from a folder containing images or videos:\n", + "# dataset = LightlyDataset(\"path/to/folder\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": 
{}, + "outputs": [], + "source": [ + "dataloader = torch.utils.data.DataLoader(\n", + " dataset,\n", + " batch_size=8,\n", + " shuffle=True,\n", + " drop_last=True,\n", + " num_workers=8,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "# L1 loss as paper suggestion\n", + "criterion = nn.L1Loss()\n", + "optimizer = torch.optim.AdamW(model.parameters(), lr=1.5e-4)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [], + "source": [ + "print(\"Starting Training\")\n", + "for epoch in range(10):\n", + " total_loss = 0\n", + " for batch in dataloader:\n", + " views = batch[0]\n", + " images = views[0].to(device) # views contains only a single view\n", + " predictions, targets = model(images)\n", + "\n", + " loss = criterion(predictions, targets)\n", + " total_loss += loss.detach()\n", + " loss.backward()\n", + " optimizer.step()\n", + " optimizer.zero_grad()\n", + " avg_loss = total_loss / len(dataloader)\n", + " print(f\"epoch: {epoch:>02}, loss: {avg_loss:.5f}\")" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": "-all" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/notebooks/pytorch/simsiam.ipynb b/examples/notebooks/pytorch/simsiam.ipynb new file mode 100644 index 000000000..fa605863a --- /dev/null +++ b/examples/notebooks/pytorch/simsiam.ipynb @@ -0,0 +1,178 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "This example requires the following dependencies to be installed:\n", + "pip install lightly" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install lightly" + ] + }, + { + "cell_type": "markdown", + "id": "2", + "metadata": {}, + "source": [ + "Note: The model and training settings do not follow the reference settings\n", + "from the paper. The settings are chosen such that the example can easily be\n", + "run on a small dataset with a single GPU." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torchvision\n", + "from torch import nn" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "from lightly.loss import NegativeCosineSimilarity\n", + "from lightly.models.modules import SimSiamPredictionHead, SimSiamProjectionHead\n", + "from lightly.transforms import SimSiamTransform" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "class SimSiam(nn.Module):\n", + " def __init__(self, backbone):\n", + " super().__init__()\n", + " self.backbone = backbone\n", + " self.projection_head = SimSiamProjectionHead(512, 512, 128)\n", + " self.prediction_head = SimSiamPredictionHead(128, 64, 128)\n", + "\n", + " def forward(self, x):\n", + " f = self.backbone(x).flatten(start_dim=1)\n", + " z = self.projection_head(f)\n", + " p = self.prediction_head(z)\n", + " z = z.detach()\n", + " return z, p" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "resnet = torchvision.models.resnet18()\n", + "backbone = nn.Sequential(*list(resnet.children())[:-1])\n", + "model = SimSiam(backbone)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "model.to(device)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "transform = SimSiamTransform(input_size=32)\n", + "dataset = torchvision.datasets.CIFAR10(\n", + " \"datasets/cifar10\", download=True, transform=transform\n", + ")\n", + "# or create a dataset from a folder containing images or videos:\n", + "# dataset = LightlyDataset(\"path/to/folder\", transform=transform)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "dataloader = torch.utils.data.DataLoader(\n", + " dataset,\n", + " batch_size=256,\n", + " shuffle=True,\n", + " drop_last=True,\n", + " num_workers=8,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [], + "source": [ + "criterion = NegativeCosineSimilarity()\n", + "optimizer = torch.optim.SGD(model.parameters(), lr=0.06)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "11", + "metadata": {}, + "outputs": [], + "source": [ + "print(\"Starting Training\")\n", + "for epoch in range(10):\n", + " total_loss = 0\n", + " for batch in dataloader:\n", + " x0, x1 = batch[0]\n", + " x0 = x0.to(device)\n", + " x1 = x1.to(device)\n", + " z0, p0 = model(x0)\n", + " z1, p1 = model(x1)\n", + " loss = 0.5 * (criterion(z0, p1) + criterion(z1, p0))\n", + " total_loss += loss.detach()\n", + " loss.backward()\n", + " optimizer.step()\n", + " optimizer.zero_grad()\n", + " avg_loss = total_loss / len(dataloader)\n", + " print(f\"epoch: {epoch:>02}, loss: {avg_loss:.5f}\")" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": "-all" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/notebooks/pytorch/smog.ipynb b/examples/notebooks/pytorch/smog.ipynb new file mode 100644 index 000000000..98cea3d66 --- 
/dev/null +++ b/examples/notebooks/pytorch/smog.ipynb @@ -0,0 +1,303 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "This example requires the following dependencies to be installed:\n", + "pip install lightly" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install lightly" + ] + }, + { + "cell_type": "markdown", + "id": "2", + "metadata": {}, + "source": [ + "Note: The model and training settings do not follow the reference settings\n", + "from the paper. The settings are chosen such that the example can easily be\n", + "run on a small dataset with a single GPU." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "import copy" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torchvision\n", + "from sklearn.cluster import KMeans\n", + "from torch import nn" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "from lightly.models import utils\n", + "from lightly.models.modules.heads import (\n", + " SMoGPredictionHead,\n", + " SMoGProjectionHead,\n", + " SMoGPrototypes,\n", + ")\n", + "from lightly.models.modules.memory_bank import MemoryBankModule\n", + "from lightly.transforms.smog_transform import SMoGTransform" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "class SMoGModel(nn.Module):\n", + " def __init__(self, backbone):\n", + " super().__init__()\n", + "\n", + " self.backbone = backbone\n", + " self.projection_head = SMoGProjectionHead(512, 2048, 128)\n", + " self.prediction_head = SMoGPredictionHead(128, 2048, 128)\n", + "\n", + " self.backbone_momentum = copy.deepcopy(self.backbone)\n", + " self.projection_head_momentum = copy.deepcopy(self.projection_head)\n", + "\n", + " utils.deactivate_requires_grad(self.backbone_momentum)\n", + " utils.deactivate_requires_grad(self.projection_head_momentum)\n", + "\n", + " self.n_groups = 300\n", + " self.smog = SMoGPrototypes(\n", + " group_features=torch.rand(self.n_groups, 128), beta=0.99\n", + " )\n", + "\n", + " def _cluster_features(self, features: torch.Tensor) -> torch.Tensor:\n", + " # clusters the features using sklearn\n", + " # (note: faiss is probably more efficient)\n", + " features = features.cpu().numpy()\n", + " kmeans = KMeans(self.n_groups).fit(features)\n", + " clustered = torch.from_numpy(kmeans.cluster_centers_).float()\n", + " clustered = torch.nn.functional.normalize(clustered, dim=1)\n", + " return clustered\n", + "\n", + " def reset_group_features(self, memory_bank):\n", + " # see https://arxiv.org/pdf/2207.06167.pdf Table 7b)\n", + " features = memory_bank.bank\n", + " group_features = self._cluster_features(features.t())\n", + " self.smog.set_group_features(group_features)\n", + "\n", + " def reset_momentum_weights(self):\n", + " # see https://arxiv.org/pdf/2207.06167.pdf Table 7b)\n", + " self.backbone_momentum = copy.deepcopy(self.backbone)\n", + " self.projection_head_momentum = copy.deepcopy(self.projection_head)\n", + " utils.deactivate_requires_grad(self.backbone_momentum)\n", + " utils.deactivate_requires_grad(self.projection_head_momentum)\n", + "\n", + " def forward(self, x):\n", + " features = self.backbone(x).flatten(start_dim=1)\n", + " 
encoded = self.projection_head(features)\n", + " predicted = self.prediction_head(encoded)\n", + " return encoded, predicted\n", + "\n", + " def forward_momentum(self, x):\n", + " features = self.backbone_momentum(x).flatten(start_dim=1)\n", + " encoded = self.projection_head_momentum(features)\n", + " return encoded" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "batch_size = 256" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "resnet = torchvision.models.resnet18()\n", + "backbone = nn.Sequential(*list(resnet.children())[:-1])\n", + "model = SMoGModel(backbone)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "# memory bank because we reset the group features every 300 iterations\n", + "memory_bank_size = 300 * batch_size\n", + "memory_bank = MemoryBankModule(size=(memory_bank_size, 128))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [], + "source": [ + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "model.to(device)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "11", + "metadata": {}, + "outputs": [], + "source": [ + "transform = SMoGTransform(\n", + " crop_sizes=(32, 32),\n", + " crop_counts=(1, 1),\n", + " gaussian_blur_probs=(0.0, 0.0),\n", + " crop_min_scales=(0.2, 0.2),\n", + " crop_max_scales=(1.0, 1.0),\n", + ")\n", + "dataset = torchvision.datasets.CIFAR10(\n", + " \"datasets/cifar10\", download=True, transform=transform\n", + ")\n", + "# or create a dataset from a folder containing images or videos:\n", + "# dataset = LightlyDataset(\"path/to/folder\", transform=transform)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "12", + "metadata": {}, + "outputs": [], + "source": [ + "dataloader = torch.utils.data.DataLoader(\n", + " dataset,\n", + " batch_size=256,\n", + " shuffle=True,\n", + " drop_last=True,\n", + " num_workers=8,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "13", + "metadata": {}, + "outputs": [], + "source": [ + "criterion = nn.CrossEntropyLoss()\n", + "optimizer = torch.optim.SGD(\n", + " model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-6\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "14", + "metadata": {}, + "outputs": [], + "source": [ + "global_step = 0" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "15", + "metadata": {}, + "outputs": [], + "source": [ + "print(\"Starting Training\")\n", + "for epoch in range(10):\n", + " total_loss = 0\n", + " for batch_idx, batch in enumerate(dataloader):\n", + " (x0, x1) = batch[0]\n", + "\n", + " if batch_idx % 2:\n", + " # swap batches every second iteration\n", + " x1, x0 = x0, x1\n", + "\n", + " x0 = x0.to(device)\n", + " x1 = x1.to(device)\n", + "\n", + " if global_step > 0 and global_step % 300 == 0:\n", + " # reset group features and weights every 300 iterations\n", + " model.reset_group_features(memory_bank=memory_bank)\n", + " model.reset_momentum_weights()\n", + " else:\n", + " # update momentum\n", + " utils.update_momentum(model.backbone, model.backbone_momentum, 0.99)\n", + " utils.update_momentum(\n", + " model.projection_head, model.projection_head_momentum, 0.99\n", + " )\n", + "\n", + " x0_encoded, x0_predicted = model(x0)\n", + " 
x1_encoded = model.forward_momentum(x1)\n", + "\n", + " # update group features and get group assignments\n", + " assignments = model.smog.assign_groups(x1_encoded)\n", + " group_features = model.smog.get_updated_group_features(x0_encoded)\n", + " logits = model.smog(x0_predicted, group_features, temperature=0.1)\n", + " model.smog.set_group_features(group_features)\n", + "\n", + " loss = criterion(logits, assignments)\n", + "\n", + " # use memory bank to periodically reset the group features with k-means\n", + " memory_bank(x0_encoded, update=True)\n", + "\n", + " loss.backward()\n", + " optimizer.step()\n", + " optimizer.zero_grad()\n", + "\n", + " global_step += 1\n", + " total_loss += loss.detach()\n", + "\n", + " avg_loss = total_loss / len(dataloader)\n", + " print(f\"epoch: {epoch:>02}, loss: {avg_loss:.5f}\")" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": "-all" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/notebooks/pytorch/swav.ipynb b/examples/notebooks/pytorch/swav.ipynb new file mode 100644 index 000000000..d2b4c46b1 --- /dev/null +++ b/examples/notebooks/pytorch/swav.ipynb @@ -0,0 +1,182 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "This example requires the following dependencies to be installed:\n", + "pip install lightly" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install lightly" + ] + }, + { + "cell_type": "markdown", + "id": "2", + "metadata": {}, + "source": [ + "Note: The model and training settings do not follow the reference settings\n", + "from the paper. The settings are chosen such that the example can easily be\n", + "run on a small dataset with a single GPU." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torchvision\n", + "from torch import nn" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "from lightly.loss import SwaVLoss\n", + "from lightly.models.modules import SwaVProjectionHead, SwaVPrototypes\n", + "from lightly.transforms.swav_transform import SwaVTransform" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "class SwaV(nn.Module):\n", + " def __init__(self, backbone):\n", + " super().__init__()\n", + " self.backbone = backbone\n", + " self.projection_head = SwaVProjectionHead(512, 512, 128)\n", + " self.prototypes = SwaVPrototypes(128, n_prototypes=512)\n", + "\n", + " def forward(self, x):\n", + " x = self.backbone(x).flatten(start_dim=1)\n", + " x = self.projection_head(x)\n", + " x = nn.functional.normalize(x, dim=1, p=2)\n", + " p = self.prototypes(x)\n", + " return p" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "resnet = torchvision.models.resnet18()\n", + "backbone = nn.Sequential(*list(resnet.children())[:-1])\n", + "model = SwaV(backbone)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "model.to(device)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "transform = SwaVTransform()\n", + "# we ignore object detection annotations by setting target_transform to return 0\n", + "dataset = torchvision.datasets.VOCDetection(\n", + " \"datasets/pascal_voc\",\n", + " download=True,\n", + " transform=transform,\n", + " target_transform=lambda t: 0,\n", + ")\n", + "# or create a dataset from a folder containing images or videos:\n", + "# dataset = LightlyDataset(\"path/to/folder\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "dataloader = torch.utils.data.DataLoader(\n", + " dataset,\n", + " batch_size=128,\n", + " shuffle=True,\n", + " drop_last=True,\n", + " num_workers=8,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [], + "source": [ + "criterion = SwaVLoss()\n", + "optimizer = torch.optim.Adam(model.parameters(), lr=0.001)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "11", + "metadata": {}, + "outputs": [], + "source": [ + "print(\"Starting Training\")\n", + "for epoch in range(10):\n", + " total_loss = 0\n", + " for batch in dataloader:\n", + " views = batch[0]\n", + " model.prototypes.normalize()\n", + " multi_crop_features = [model(view.to(device)) for view in views]\n", + " high_resolution = multi_crop_features[:2]\n", + " low_resolution = multi_crop_features[2:]\n", + " loss = criterion(high_resolution, low_resolution)\n", + " total_loss += loss.detach()\n", + " loss.backward()\n", + " optimizer.step()\n", + " optimizer.zero_grad()\n", + " avg_loss = total_loss / len(dataloader)\n", + " print(f\"epoch: {epoch:>02}, loss: {avg_loss:.5f}\")" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": "-all" + } 
+ }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/notebooks/pytorch/swav_queue.ipynb b/examples/notebooks/pytorch/swav_queue.ipynb new file mode 100644 index 000000000..d39175200 --- /dev/null +++ b/examples/notebooks/pytorch/swav_queue.ipynb @@ -0,0 +1,231 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "This example requires the following dependencies to be installed:\n", + "pip install lightly" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install lightly" + ] + }, + { + "cell_type": "markdown", + "id": "2", + "metadata": {}, + "source": [ + "Note: The model and training settings do not follow the reference settings\n", + "from the paper. The settings are chosen such that the example can easily be\n", + "run on a small dataset with a single GPU." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torchvision\n", + "from torch import nn" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "from lightly.loss import SwaVLoss\n", + "from lightly.models.modules import SwaVProjectionHead, SwaVPrototypes\n", + "from lightly.models.modules.memory_bank import MemoryBankModule\n", + "from lightly.transforms.swav_transform import SwaVTransform" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "class SwaV(nn.Module):\n", + " def __init__(self, backbone):\n", + " super().__init__()\n", + " self.backbone = backbone\n", + " self.projection_head = SwaVProjectionHead(512, 512, 128)\n", + " self.prototypes = SwaVPrototypes(128, 512, 1)\n", + "\n", + " self.start_queue_at_epoch = 2\n", + " self.queues = nn.ModuleList(\n", + " [MemoryBankModule(size=(3840, 128)) for _ in range(2)]\n", + " )\n", + "\n", + " def forward(self, high_resolution, low_resolution, epoch):\n", + " self.prototypes.normalize()\n", + "\n", + " high_resolution_features = [self._subforward(x) for x in high_resolution]\n", + " low_resolution_features = [self._subforward(x) for x in low_resolution]\n", + "\n", + " high_resolution_prototypes = [\n", + " self.prototypes(x, epoch) for x in high_resolution_features\n", + " ]\n", + " low_resolution_prototypes = [\n", + " self.prototypes(x, epoch) for x in low_resolution_features\n", + " ]\n", + " queue_prototypes = self._get_queue_prototypes(high_resolution_features, epoch)\n", + "\n", + " return high_resolution_prototypes, low_resolution_prototypes, queue_prototypes\n", + "\n", + " def _subforward(self, input):\n", + " features = self.backbone(input).flatten(start_dim=1)\n", + " features = self.projection_head(features)\n", + " features = nn.functional.normalize(features, dim=1, p=2)\n", + " return features\n", + "\n", + " @torch.no_grad()\n", + " def _get_queue_prototypes(self, high_resolution_features, epoch):\n", + " if len(high_resolution_features) != len(self.queues):\n", + " raise ValueError(\n", + " f\"The number of queues ({len(self.queues)}) should be equal to the number of high \"\n", + " f\"resolution inputs ({len(high_resolution_features)}). 
Set `n_queues` accordingly.\"\n", + " )\n", + "\n", + " # Get the queue features\n", + " queue_features = []\n", + " for i in range(len(self.queues)):\n", + " _, features = self.queues[i](high_resolution_features[i], update=True)\n", + " # Queue features are in (num_ftrs X queue_length) shape, while the high res\n", + " # features are in (batch_size X num_ftrs). Swap the axes for interoperability.\n", + " features = torch.permute(features, (1, 0))\n", + " queue_features.append(features)\n", + "\n", + " # If loss calculation with queue prototypes starts at a later epoch,\n", + " # just queue the features and return None instead of queue prototypes.\n", + " if self.start_queue_at_epoch > 0 and epoch < self.start_queue_at_epoch:\n", + " return None\n", + "\n", + " # Assign prototypes\n", + " queue_prototypes = [self.prototypes(x, epoch) for x in queue_features]\n", + " return queue_prototypes" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "resnet = torchvision.models.resnet18()\n", + "backbone = nn.Sequential(*list(resnet.children())[:-1])\n", + "model = SwaV(backbone)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "model.to(device)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "transform = SwaVTransform()\n", + "# we ignore object detection annotations by setting target_transform to return 0\n", + "dataset = torchvision.datasets.VOCDetection(\n", + " \"datasets/pascal_voc\",\n", + " download=True,\n", + " transform=transform,\n", + " target_transform=lambda t: 0,\n", + ")\n", + "# or create a dataset from a folder containing images or videos:\n", + "# dataset = LightlyDataset(\"path/to/folder\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "dataloader = torch.utils.data.DataLoader(\n", + " dataset,\n", + " batch_size=128,\n", + " shuffle=True,\n", + " drop_last=True,\n", + " num_workers=8,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [], + "source": [ + "criterion = SwaVLoss()\n", + "optimizer = torch.optim.Adam(model.parameters(), lr=0.001)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "11", + "metadata": {}, + "outputs": [], + "source": [ + "print(\"Starting Training\")\n", + "for epoch in range(10):\n", + " total_loss = 0\n", + " for batch in dataloader:\n", + " views = batch[0]\n", + " views = [view.to(device) for view in views]\n", + " high_resolution, low_resolution = views[:2], views[2:]\n", + " high_resolution, low_resolution, queue = model(\n", + " high_resolution, low_resolution, epoch\n", + " )\n", + " loss = criterion(high_resolution, low_resolution, queue)\n", + " total_loss += loss.detach()\n", + " loss.backward()\n", + " optimizer.step()\n", + " optimizer.zero_grad()\n", + "\n", + " avg_loss = total_loss / len(dataloader)\n", + " print(f\"epoch: {epoch:>02}, loss: {avg_loss:.5f}\")" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": "-all" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/notebooks/pytorch/tico.ipynb b/examples/notebooks/pytorch/tico.ipynb new file mode 100644 index 
000000000..54dcdbb34 --- /dev/null +++ b/examples/notebooks/pytorch/tico.ipynb @@ -0,0 +1,233 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "This example requires the following dependencies to be installed:\n", + "pip install lightly" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install lightly" + ] + }, + { + "cell_type": "markdown", + "id": "2", + "metadata": {}, + "source": [ + "Note: The model and training settings do not follow the reference settings\n", + "from the paper. The settings are chosen such that the example can easily be\n", + "run on a small dataset with a single GPU." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "import copy" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torchvision\n", + "from torch import nn" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "from lightly.loss.tico_loss import TiCoLoss\n", + "from lightly.models.modules.heads import TiCoProjectionHead\n", + "from lightly.models.utils import deactivate_requires_grad, update_momentum\n", + "from lightly.transforms.tico_transform import (\n", + "    TiCoTransform,\n", + "    TiCoView1Transform,\n", + "    TiCoView2Transform,\n", + ")\n", + "from lightly.utils.scheduler import cosine_schedule" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "class TiCo(nn.Module):\n", + "    def __init__(self, backbone):\n", + "        super().__init__()\n", + "\n", + "        self.backbone = backbone\n", + "        self.projection_head = TiCoProjectionHead(512, 1024, 256)\n", + "\n", + "        self.backbone_momentum = copy.deepcopy(self.backbone)\n", + "        self.projection_head_momentum = copy.deepcopy(self.projection_head)\n", + "\n", + "        deactivate_requires_grad(self.backbone_momentum)\n", + "        deactivate_requires_grad(self.projection_head_momentum)\n", + "\n", + "    def forward(self, x):\n", + "        y = self.backbone(x).flatten(start_dim=1)\n", + "        z = self.projection_head(y)\n", + "        return z\n", + "\n", + "    def forward_momentum(self, x):\n", + "        y = self.backbone_momentum(x).flatten(start_dim=1)\n", + "        z = self.projection_head_momentum(y)\n", + "        z = z.detach()\n", + "        return z" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "resnet = torchvision.models.resnet18()\n", + "backbone = nn.Sequential(*list(resnet.children())[:-1])\n", + "model = TiCo(backbone)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "model.to(device)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "# TiCo uses the same augmentations as BYOL.\n", + "# We disable resizing and gaussian blur for cifar10.\n", + "transform = TiCoTransform(\n", + "    view_1_transform=TiCoView1Transform(input_size=32, gaussian_blur=0.0),\n", + "    view_2_transform=TiCoView2Transform(input_size=32, gaussian_blur=0.0),\n", + ")\n", + "dataset = torchvision.datasets.CIFAR10(\n", + "    \"datasets/cifar10\", download=True, transform=transform\n", + 
")\n", + "# or create a dataset from a folder containing images or videos:\n", + "# dataset = LightlyDataset(\"path/to/folder\", transform=transform)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [], + "source": [ + "dataloader = torch.utils.data.DataLoader(\n", + " dataset,\n", + " batch_size=256,\n", + " shuffle=True,\n", + " drop_last=True,\n", + " num_workers=8,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "11", + "metadata": {}, + "outputs": [], + "source": [ + "criterion = TiCoLoss()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "12", + "metadata": {}, + "outputs": [], + "source": [ + "optimizer = torch.optim.SGD(model.parameters(), lr=0.06)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "13", + "metadata": {}, + "outputs": [], + "source": [ + "epochs = 10" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "14", + "metadata": {}, + "outputs": [], + "source": [ + "print(\"Starting Training\")\n", + "for epoch in range(epochs):\n", + " total_loss = 0\n", + " momentum_val = cosine_schedule(epoch, epochs, 0.996, 1)\n", + " for batch in dataloader:\n", + " x0, x1 = batch[0]\n", + " update_momentum(model.backbone, model.backbone_momentum, m=momentum_val)\n", + " update_momentum(\n", + " model.projection_head, model.projection_head_momentum, m=momentum_val\n", + " )\n", + " x0 = x0.to(device)\n", + " x1 = x1.to(device)\n", + " z0 = model(x0)\n", + " z1 = model.forward_momentum(x1)\n", + " loss = criterion(z0, z1)\n", + " total_loss += loss.detach()\n", + " loss.backward()\n", + " optimizer.step()\n", + " optimizer.zero_grad()\n", + " avg_loss = total_loss / len(dataloader)\n", + " print(f\"epoch: {epoch:>02}, loss: {avg_loss:.5f}\")" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": "-all" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/notebooks/pytorch/vicreg.ipynb b/examples/notebooks/pytorch/vicreg.ipynb new file mode 100644 index 000000000..7af4c1c38 --- /dev/null +++ b/examples/notebooks/pytorch/vicreg.ipynb @@ -0,0 +1,173 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "This example requires the following dependencies to be installed:\n", + "pip install lightly" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install lightly" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torchvision\n", + "from torch import nn" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "## The projection head is the same as the Barlow Twins one\n", + "from lightly.loss import VICRegLoss" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "## The projection head is the same as the Barlow Twins one\n", + "from lightly.loss.vicreg_loss import VICRegLoss\n", + "from lightly.models.modules.heads import VICRegProjectionHead\n", + "from lightly.transforms.vicreg_transform import VICRegTransform" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "class VICReg(nn.Module):\n", 
+ " def __init__(self, backbone):\n", + " super().__init__()\n", + " self.backbone = backbone\n", + " self.projection_head = VICRegProjectionHead(\n", + " input_dim=512,\n", + " hidden_dim=2048,\n", + " output_dim=2048,\n", + " num_layers=2,\n", + " )\n", + "\n", + " def forward(self, x):\n", + " x = self.backbone(x).flatten(start_dim=1)\n", + " z = self.projection_head(x)\n", + " return z" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "resnet = torchvision.models.resnet18()\n", + "backbone = nn.Sequential(*list(resnet.children())[:-1])\n", + "model = VICReg(backbone)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "model.to(device)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "transform = VICRegTransform(input_size=32)\n", + "dataset = torchvision.datasets.CIFAR10(\n", + " \"datasets/cifar10\", download=True, transform=transform\n", + ")\n", + "# or create a dataset from a folder containing images or videos:\n", + "# dataset = LightlyDataset(\"path/to/folder\", transform=transform)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "dataloader = torch.utils.data.DataLoader(\n", + " dataset,\n", + " batch_size=256,\n", + " shuffle=True,\n", + " drop_last=True,\n", + " num_workers=8,\n", + ")\n", + "criterion = VICRegLoss()\n", + "optimizer = torch.optim.SGD(model.parameters(), lr=0.06)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [], + "source": [ + "print(\"Starting Training\")\n", + "for epoch in range(10):\n", + " total_loss = 0\n", + " for batch in dataloader:\n", + " x0, x1 = batch[0]\n", + " x0 = x0.to(device)\n", + " x1 = x1.to(device)\n", + " z0 = model(x0)\n", + " z1 = model(x1)\n", + " loss = criterion(z0, z1)\n", + " total_loss += loss.detach()\n", + " loss.backward()\n", + " optimizer.step()\n", + " optimizer.zero_grad()\n", + " avg_loss = total_loss / len(dataloader)\n", + " print(f\"epoch: {epoch:>02}, loss: {avg_loss:.5f}\")" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": "-all" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/notebooks/pytorch/vicregl.ipynb b/examples/notebooks/pytorch/vicregl.ipynb new file mode 100644 index 000000000..d7cc1dcec --- /dev/null +++ b/examples/notebooks/pytorch/vicregl.ipynb @@ -0,0 +1,189 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "This example requires the following dependencies to be installed:\n", + "pip install lightly" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install lightly" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torchvision\n", + "from torch import nn" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "from lightly.loss import VICRegLLoss" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": 
[], + "source": [ + "## The global projection head is the same as the Barlow Twins one\n", + "from lightly.models.modules import BarlowTwinsProjectionHead\n", + "from lightly.models.modules.heads import VicRegLLocalProjectionHead\n", + "from lightly.transforms.vicregl_transform import VICRegLTransform" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "class VICRegL(nn.Module):\n", + " def __init__(self, backbone):\n", + " super().__init__()\n", + " self.backbone = backbone\n", + " self.projection_head = BarlowTwinsProjectionHead(512, 2048, 2048)\n", + " self.local_projection_head = VicRegLLocalProjectionHead(512, 128, 128)\n", + " self.average_pool = nn.AdaptiveAvgPool2d(output_size=(1, 1))\n", + "\n", + " def forward(self, x):\n", + " x = self.backbone(x)\n", + " y = self.average_pool(x).flatten(start_dim=1)\n", + " z = self.projection_head(y)\n", + " y_local = x.permute(0, 2, 3, 1) # (B, D, W, H) to (B, W, H, D)\n", + " z_local = self.local_projection_head(y_local)\n", + " return z, z_local" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "resnet = torchvision.models.resnet18()\n", + "backbone = nn.Sequential(*list(resnet.children())[:-2])\n", + "model = VICRegL(backbone)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "model.to(device)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "transform = VICRegLTransform(n_local_views=0)\n", + "# we ignore object detection annotations by setting target_transform to return 0\n", + "dataset = torchvision.datasets.VOCDetection(\n", + " \"datasets/pascal_voc\",\n", + " download=True,\n", + " transform=transform,\n", + " target_transform=lambda t: 0,\n", + ")\n", + "# or create a dataset from a folder containing images or videos:\n", + "# dataset = LightlyDataset(\"path/to/folder\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "dataloader = torch.utils.data.DataLoader(\n", + " dataset,\n", + " batch_size=256,\n", + " shuffle=True,\n", + " drop_last=True,\n", + " num_workers=8,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [], + "source": [ + "criterion = VICRegLLoss()\n", + "optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "11", + "metadata": {}, + "outputs": [], + "source": [ + "print(\"Starting Training\")\n", + "for epoch in range(10):\n", + " total_loss = 0\n", + " for views_and_grids, _ in dataloader:\n", + " views_and_grids = [x.to(device) for x in views_and_grids]\n", + " views = views_and_grids[: len(views_and_grids) // 2]\n", + " grids = views_and_grids[len(views_and_grids) // 2 :]\n", + " features = [model(view) for view in views]\n", + " loss = criterion(\n", + " global_view_features=features[:2],\n", + " global_view_grids=grids[:2],\n", + " local_view_features=features[2:],\n", + " local_view_grids=grids[2:],\n", + " )\n", + " total_loss += loss.detach()\n", + " loss.backward()\n", + " optimizer.step()\n", + " optimizer.zero_grad()\n", + " avg_loss = total_loss / len(dataloader)\n", + " print(f\"epoch: {epoch:>02}, 
loss: {avg_loss:.5f}\")" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": "-all" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/notebooks/pytorch_lightning/aim.ipynb b/examples/notebooks/pytorch_lightning/aim.ipynb new file mode 100644 index 000000000..c7f3b54eb --- /dev/null +++ b/examples/notebooks/pytorch_lightning/aim.ipynb @@ -0,0 +1,193 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "This example requires the following dependencies to be installed:\n", + "pip install \"lightly[timm]\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install \"lightly[timm]\"" + ] + }, + { + "cell_type": "markdown", + "id": "2", + "metadata": {}, + "source": [ + "Note: The model and training settings do not follow the reference settings\n", + "from the paper. The settings are chosen such that the example can easily be\n", + "run on a small dataset with a single GPU." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "import pytorch_lightning as pl\n", + "import torch\n", + "import torchvision\n", + "from torch import nn" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "from lightly.models import utils\n", + "from lightly.models.modules import AIMPredictionHead, MaskedCausalVisionTransformer\n", + "from lightly.transforms import AIMTransform" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "class AIM(pl.LightningModule):\n", + " def __init__(self) -> None:\n", + " super().__init__()\n", + " vit = MaskedCausalVisionTransformer(\n", + " img_size=224,\n", + " patch_size=32,\n", + " embed_dim=768,\n", + " depth=12,\n", + " num_heads=12,\n", + " qk_norm=False,\n", + " class_token=False,\n", + " no_embed_class=True,\n", + " )\n", + " utils.initialize_2d_sine_cosine_positional_embedding(\n", + " pos_embedding=vit.pos_embed, has_class_token=vit.has_class_token\n", + " )\n", + " self.patch_size = vit.patch_embed.patch_size[0]\n", + " self.num_patches = vit.patch_embed.num_patches\n", + "\n", + " self.backbone = vit\n", + " self.projection_head = AIMPredictionHead(\n", + " input_dim=vit.embed_dim, output_dim=3 * self.patch_size**2, num_blocks=1\n", + " )\n", + "\n", + " self.criterion = nn.MSELoss()\n", + "\n", + " def training_step(self, batch, batch_idx):\n", + " views, targets = batch[0], batch[1]\n", + " images = views[0] # AIM has only a single view\n", + " batch_size = images.shape[0]\n", + "\n", + " mask = utils.random_prefix_mask(\n", + " size=(batch_size, self.num_patches),\n", + " max_prefix_length=self.num_patches - 1,\n", + " device=images.device,\n", + " )\n", + " features = self.backbone.forward_features(images, mask=mask)\n", + " # Add positional embedding before head.\n", + " features = self.backbone._pos_embed(features)\n", + " predictions = self.projection_head(features)\n", + "\n", + " # Convert images to patches and normalize them.\n", + " patches = utils.patchify(images, self.patch_size)\n", + " patches = utils.normalize_mean_var(patches, dim=-1)\n", + "\n", + " loss = self.criterion(predictions, patches)\n", + " return loss\n", + "\n", + " def configure_optimizers(self):\n", + " optim = 
torch.optim.AdamW(self.parameters(), lr=1.5e-4)\n", + " return optim" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "model = AIM()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "transform = AIMTransform()\n", + "# we ignore object detection annotations by setting target_transform to return 0\n", + "dataset = torchvision.datasets.VOCDetection(\n", + " \"datasets/pascal_voc\",\n", + " download=True,\n", + " transform=transform,\n", + " target_transform=lambda t: 0,\n", + ")\n", + "# or create a dataset from a folder containing images or videos:\n", + "# dataset = LightlyDataset(\"path/to/folder\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "dataloader = torch.utils.data.DataLoader(\n", + " dataset,\n", + " batch_size=256,\n", + " shuffle=True,\n", + " drop_last=True,\n", + " num_workers=8,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "accelerator = \"gpu\" if torch.cuda.is_available() else \"cpu\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [], + "source": [ + "trainer = pl.Trainer(max_epochs=10, devices=1, accelerator=accelerator)\n", + "trainer.fit(model=model, train_dataloaders=dataloader)" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": "-all" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/notebooks/pytorch_lightning/barlowtwins.ipynb b/examples/notebooks/pytorch_lightning/barlowtwins.ipynb new file mode 100644 index 000000000..03d12d046 --- /dev/null +++ b/examples/notebooks/pytorch_lightning/barlowtwins.ipynb @@ -0,0 +1,170 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "This example requires the following dependencies to be installed:\n", + "pip install lightly" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install lightly" + ] + }, + { + "cell_type": "markdown", + "id": "2", + "metadata": {}, + "source": [ + "Note: The model and training settings do not follow the reference settings\n", + "from the paper. The settings are chosen such that the example can easily be\n", + "run on a small dataset with a single GPU." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "import pytorch_lightning as pl\n", + "import torch\n", + "import torchvision\n", + "from torch import nn" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "from lightly.loss import BarlowTwinsLoss\n", + "from lightly.models.modules import BarlowTwinsProjectionHead\n", + "from lightly.transforms.byol_transform import (\n", + " BYOLTransform,\n", + " BYOLView1Transform,\n", + " BYOLView2Transform,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "class BarlowTwins(pl.LightningModule):\n", + " def __init__(self):\n", + " super().__init__()\n", + " resnet = torchvision.models.resnet18()\n", + " self.backbone = nn.Sequential(*list(resnet.children())[:-1])\n", + " self.projection_head = BarlowTwinsProjectionHead(512, 2048, 2048)\n", + " self.criterion = BarlowTwinsLoss()\n", + "\n", + " def forward(self, x):\n", + " x = self.backbone(x).flatten(start_dim=1)\n", + " z = self.projection_head(x)\n", + " return z\n", + "\n", + " def training_step(self, batch, batch_index):\n", + " (x0, x1) = batch[0]\n", + " z0 = self.forward(x0)\n", + " z1 = self.forward(x1)\n", + " loss = self.criterion(z0, z1)\n", + " return loss\n", + "\n", + " def configure_optimizers(self):\n", + " optim = torch.optim.SGD(self.parameters(), lr=0.06)\n", + " return optim" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "model = BarlowTwins()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "# BarlowTwins uses BYOL augmentations.\n", + "# We disable resizing and gaussian blur for cifar10.\n", + "transform = BYOLTransform(\n", + " view_1_transform=BYOLView1Transform(input_size=32, gaussian_blur=0.0),\n", + " view_2_transform=BYOLView2Transform(input_size=32, gaussian_blur=0.0),\n", + ")\n", + "dataset = torchvision.datasets.CIFAR10(\n", + " \"datasets/cifar10\", download=True, transform=transform\n", + ")\n", + "# or create a dataset from a folder containing images or videos:\n", + "# dataset = LightlyDataset(\"path/to/folder\", transform=transform)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "dataloader = torch.utils.data.DataLoader(\n", + " dataset,\n", + " batch_size=256,\n", + " shuffle=True,\n", + " drop_last=True,\n", + " num_workers=8,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "accelerator = \"gpu\" if torch.cuda.is_available() else \"cpu\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [], + "source": [ + "trainer = pl.Trainer(max_epochs=10, devices=1, accelerator=accelerator)\n", + "trainer.fit(model=model, train_dataloaders=dataloader)" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": "-all" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/notebooks/pytorch_lightning/byol.ipynb b/examples/notebooks/pytorch_lightning/byol.ipynb new file mode 100644 index 000000000..9c74957fa --- /dev/null +++ 
b/examples/notebooks/pytorch_lightning/byol.ipynb @@ -0,0 +1,200 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "This example requires the following dependencies to be installed:\n", + "pip install lightly" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install lightly" + ] + }, + { + "cell_type": "markdown", + "id": "2", + "metadata": {}, + "source": [ + "Note: The model and training settings do not follow the reference settings\n", + "from the paper. The settings are chosen such that the example can easily be\n", + "run on a small dataset with a single GPU." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "import copy" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "import pytorch_lightning as pl\n", + "import torch\n", + "import torchvision\n", + "from torch import nn" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "from lightly.loss import NegativeCosineSimilarity\n", + "from lightly.models.modules import BYOLPredictionHead, BYOLProjectionHead\n", + "from lightly.models.utils import deactivate_requires_grad, update_momentum\n", + "from lightly.transforms.byol_transform import (\n", + " BYOLTransform,\n", + " BYOLView1Transform,\n", + " BYOLView2Transform,\n", + ")\n", + "from lightly.utils.scheduler import cosine_schedule" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "class BYOL(pl.LightningModule):\n", + " def __init__(self):\n", + " super().__init__()\n", + " resnet = torchvision.models.resnet18()\n", + " self.backbone = nn.Sequential(*list(resnet.children())[:-1])\n", + " self.projection_head = BYOLProjectionHead(512, 1024, 256)\n", + " self.prediction_head = BYOLPredictionHead(256, 1024, 256)\n", + "\n", + " self.backbone_momentum = copy.deepcopy(self.backbone)\n", + " self.projection_head_momentum = copy.deepcopy(self.projection_head)\n", + "\n", + " deactivate_requires_grad(self.backbone_momentum)\n", + " deactivate_requires_grad(self.projection_head_momentum)\n", + "\n", + " self.criterion = NegativeCosineSimilarity()\n", + "\n", + " def forward(self, x):\n", + " y = self.backbone(x).flatten(start_dim=1)\n", + " z = self.projection_head(y)\n", + " p = self.prediction_head(z)\n", + " return p\n", + "\n", + " def forward_momentum(self, x):\n", + " y = self.backbone_momentum(x).flatten(start_dim=1)\n", + " z = self.projection_head_momentum(y)\n", + " z = z.detach()\n", + " return z\n", + "\n", + " def training_step(self, batch, batch_idx):\n", + " momentum = cosine_schedule(self.current_epoch, 10, 0.996, 1)\n", + " update_momentum(self.backbone, self.backbone_momentum, m=momentum)\n", + " update_momentum(self.projection_head, self.projection_head_momentum, m=momentum)\n", + " (x0, x1) = batch[0]\n", + " p0 = self.forward(x0)\n", + " z0 = self.forward_momentum(x0)\n", + " p1 = self.forward(x1)\n", + " z1 = self.forward_momentum(x1)\n", + " loss = 0.5 * (self.criterion(p0, z1) + self.criterion(p1, z0))\n", + " return loss\n", + "\n", + " def configure_optimizers(self):\n", + " return torch.optim.SGD(self.parameters(), lr=0.06)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": 
[], + "source": [ + "model = BYOL()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "# We disable resizing and gaussian blur for cifar10.\n", + "transform = BYOLTransform(\n", + " view_1_transform=BYOLView1Transform(input_size=32, gaussian_blur=0.0),\n", + " view_2_transform=BYOLView2Transform(input_size=32, gaussian_blur=0.0),\n", + ")\n", + "dataset = torchvision.datasets.CIFAR10(\n", + " \"datasets/cifar10\", download=True, transform=transform\n", + ")\n", + "# or create a dataset from a folder containing images or videos:\n", + "# dataset = LightlyDataset(\"path/to/folder\", transform=transform)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "dataloader = torch.utils.data.DataLoader(\n", + " dataset,\n", + " batch_size=256,\n", + " shuffle=True,\n", + " drop_last=True,\n", + " num_workers=8,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [], + "source": [ + "accelerator = \"gpu\" if torch.cuda.is_available() else \"cpu\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "11", + "metadata": {}, + "outputs": [], + "source": [ + "trainer = pl.Trainer(max_epochs=10, devices=1, accelerator=accelerator)\n", + "trainer.fit(model=model, train_dataloaders=dataloader)" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": "-all" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/notebooks/pytorch_lightning/dcl.ipynb b/examples/notebooks/pytorch_lightning/dcl.ipynb new file mode 100644 index 000000000..ec3b03319 --- /dev/null +++ b/examples/notebooks/pytorch_lightning/dcl.ipynb @@ -0,0 +1,154 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "This example requires the following dependencies to be installed:\n", + "pip install lightly" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install lightly" + ] + }, + { + "cell_type": "markdown", + "id": "2", + "metadata": {}, + "source": [ + "Note: The model and training settings do not follow the reference settings\n", + "from the paper. The settings are chosen such that the example can easily be\n", + "run on a small dataset with a single GPU." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "import pytorch_lightning as pl\n", + "import torch\n", + "import torchvision\n", + "from torch import nn" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "from lightly.loss import DCLLoss\n", + "from lightly.models.modules import SimCLRProjectionHead\n", + "from lightly.transforms.simclr_transform import SimCLRTransform" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "class DCL(pl.LightningModule):\n", + " def __init__(self):\n", + " super().__init__()\n", + " resnet = torchvision.models.resnet18()\n", + " self.backbone = nn.Sequential(*list(resnet.children())[:-1])\n", + " self.projection_head = SimCLRProjectionHead(512, 2048, 2048)\n", + " self.criterion = DCLLoss()\n", + " # or use the weighted DCLW loss:\n", + " # self.criterion = DCLWLoss()\n", + "\n", + " def forward(self, x):\n", + " x = self.backbone(x).flatten(start_dim=1)\n", + " z = self.projection_head(x)\n", + " return z\n", + "\n", + " def training_step(self, batch, batch_index):\n", + " (x0, x1) = batch[0]\n", + " z0 = self.forward(x0)\n", + " z1 = self.forward(x1)\n", + " loss = self.criterion(z0, z1)\n", + " return loss\n", + "\n", + " def configure_optimizers(self):\n", + " optim = torch.optim.SGD(self.parameters(), lr=0.06)\n", + " return optim" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "model = DCL()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "transform = SimCLRTransform(input_size=32)\n", + "dataset = torchvision.datasets.CIFAR10(\n", + " \"datasets/cifar10\", download=True, transform=transform\n", + ")\n", + "# or create a dataset from a folder containing images or videos:\n", + "# dataset = LightlyDataset(\"path/to/folder\", transform=transform)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "dataloader = torch.utils.data.DataLoader(\n", + " dataset,\n", + " batch_size=256,\n", + " shuffle=True,\n", + " drop_last=True,\n", + " num_workers=8,\n", + ")\n", + "accelerator = \"gpu\" if torch.cuda.is_available() else \"cpu\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "trainer = pl.Trainer(max_epochs=10, devices=1, accelerator=accelerator)\n", + "trainer.fit(model=model, train_dataloaders=dataloader)" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": "-all" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/notebooks/pytorch_lightning/densecl.ipynb b/examples/notebooks/pytorch_lightning/densecl.ipynb new file mode 100644 index 000000000..b17cf74d6 --- /dev/null +++ b/examples/notebooks/pytorch_lightning/densecl.ipynb @@ -0,0 +1,221 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "This example requires the following dependencies to be installed:\n", + "pip install lightly" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install lightly" + ] + }, + { + "cell_type": 
"markdown", + "id": "2", + "metadata": {}, + "source": [ + "Note: The model and training settings do not follow the reference settings\n", + "from the paper. The settings are chosen such that the example can easily be\n", + "run on a small dataset with a single GPU." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "import copy" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "import pytorch_lightning as pl\n", + "import torch\n", + "import torchvision\n", + "from torch import nn" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "from lightly.loss import NTXentLoss\n", + "from lightly.models import utils\n", + "from lightly.models.modules import DenseCLProjectionHead\n", + "from lightly.transforms import DenseCLTransform\n", + "from lightly.utils.scheduler import cosine_schedule" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "class DenseCL(pl.LightningModule):\n", + " def __init__(self):\n", + " super().__init__()\n", + " resnet = torchvision.models.resnet18()\n", + " self.backbone = nn.Sequential(*list(resnet.children())[:-2])\n", + " self.projection_head_global = DenseCLProjectionHead(512, 512, 128)\n", + " self.projection_head_local = DenseCLProjectionHead(512, 512, 128)\n", + "\n", + " self.backbone_momentum = copy.deepcopy(self.backbone)\n", + " self.projection_head_global_momentum = copy.deepcopy(\n", + " self.projection_head_global\n", + " )\n", + " self.projection_head_local_momentum = copy.deepcopy(self.projection_head_local)\n", + " self.pool = nn.AdaptiveAvgPool2d((1, 1))\n", + "\n", + " utils.deactivate_requires_grad(self.backbone_momentum)\n", + " utils.deactivate_requires_grad(self.projection_head_global_momentum)\n", + " utils.deactivate_requires_grad(self.projection_head_local_momentum)\n", + "\n", + " self.criterion_global = NTXentLoss(memory_bank_size=(4096, 128))\n", + " self.criterion_local = NTXentLoss(memory_bank_size=(4096, 128))\n", + "\n", + " def forward(self, x):\n", + " query_features = self.backbone(x)\n", + " query_global = self.pool(query_features).flatten(start_dim=1)\n", + " query_global = self.projection_head_global(query_global)\n", + " query_features = query_features.flatten(start_dim=2).permute(0, 2, 1)\n", + " query_local = self.projection_head_local(query_features)\n", + " # Shapes: (B, H*W, C), (B, D), (B, H*W, D)\n", + " return query_features, query_global, query_local\n", + "\n", + " @torch.no_grad()\n", + " def forward_momentum(self, x):\n", + " key_features = self.backbone(x)\n", + " key_global = self.pool(key_features).flatten(start_dim=1)\n", + " key_global = self.projection_head_global(key_global)\n", + " key_features = key_features.flatten(start_dim=2).permute(0, 2, 1)\n", + " key_local = self.projection_head_local(key_features)\n", + " return key_features, key_global, key_local\n", + "\n", + " def training_step(self, batch, batch_idx):\n", + " momentum = cosine_schedule(self.current_epoch, 10, 0.996, 1)\n", + " utils.update_momentum(model.backbone, model.backbone_momentum, m=momentum)\n", + " utils.update_momentum(\n", + " model.projection_head_global,\n", + " model.projection_head_global_momentum,\n", + " m=momentum,\n", + " )\n", + " utils.update_momentum(\n", + " model.projection_head_local,\n", + " 
model.projection_head_local_momentum,\n", + " m=momentum,\n", + " )\n", + "\n", + " x_query, x_key = batch[0]\n", + " query_features, query_global, query_local = self(x_query)\n", + " key_features, key_global, key_local = self.forward_momentum(x_key)\n", + "\n", + " key_local = utils.select_most_similar(query_features, key_features, key_local)\n", + " query_local = query_local.flatten(end_dim=1)\n", + " key_local = key_local.flatten(end_dim=1)\n", + "\n", + " loss_global = self.criterion_global(query_global, key_global)\n", + " loss_local = self.criterion_local(query_local, key_local)\n", + " lambda_ = 0.5\n", + " loss = (1 - lambda_) * loss_global + lambda_ * loss_local\n", + " return loss\n", + "\n", + " def configure_optimizers(self):\n", + " optim = torch.optim.SGD(self.parameters(), lr=0.06)\n", + " return optim" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "model = DenseCL()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "transform = DenseCLTransform(input_size=32)\n", + "dataset = torchvision.datasets.CIFAR10(\n", + " \"datasets/cifar10\", download=True, transform=transform\n", + ")\n", + "# or create a dataset from a folder containing images or videos:\n", + "# dataset = LightlyDataset(\"path/to/folder\", transform=transform)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "dataloader = torch.utils.data.DataLoader(\n", + " dataset,\n", + " batch_size=256,\n", + " shuffle=True,\n", + " drop_last=True,\n", + " num_workers=8,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [], + "source": [ + "accelerator = \"gpu\" if torch.cuda.is_available() else \"cpu\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "11", + "metadata": {}, + "outputs": [], + "source": [ + "trainer = pl.Trainer(max_epochs=10, devices=1, accelerator=accelerator)\n", + "trainer.fit(model=model, train_dataloaders=dataloader)" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": "-all" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/notebooks/pytorch_lightning/dino.ipynb b/examples/notebooks/pytorch_lightning/dino.ipynb new file mode 100644 index 000000000..d43a6fd5a --- /dev/null +++ b/examples/notebooks/pytorch_lightning/dino.ipynb @@ -0,0 +1,204 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "This example requires the following dependencies to be installed:\n", + "pip install lightly" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install lightly" + ] + }, + { + "cell_type": "markdown", + "id": "2", + "metadata": {}, + "source": [ + "Note: The model and training settings do not follow the reference settings\n", + "from the paper. The settings are chosen such that the example can easily be\n", + "run on a small dataset with a single GPU." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "import copy" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "import pytorch_lightning as pl\n", + "import torch\n", + "import torchvision\n", + "from torch import nn" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "from lightly.loss import DINOLoss\n", + "from lightly.models.modules import DINOProjectionHead\n", + "from lightly.models.utils import deactivate_requires_grad, update_momentum\n", + "from lightly.transforms.dino_transform import DINOTransform\n", + "from lightly.utils.scheduler import cosine_schedule" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "class DINO(pl.LightningModule):\n", + " def __init__(self):\n", + " super().__init__()\n", + " resnet = torchvision.models.resnet18()\n", + " backbone = nn.Sequential(*list(resnet.children())[:-1])\n", + " input_dim = 512\n", + " # instead of a resnet you can also use a vision transformer backbone as in the\n", + " # original paper (you might have to reduce the batch size in this case):\n", + " # backbone = torch.hub.load('facebookresearch/dino:main', 'dino_vits16', pretrained=False)\n", + " # input_dim = backbone.embed_dim\n", + "\n", + " self.student_backbone = backbone\n", + " self.student_head = DINOProjectionHead(\n", + " input_dim, 512, 64, 2048, freeze_last_layer=1\n", + " )\n", + " self.teacher_backbone = copy.deepcopy(backbone)\n", + " self.teacher_head = DINOProjectionHead(input_dim, 512, 64, 2048)\n", + " deactivate_requires_grad(self.teacher_backbone)\n", + " deactivate_requires_grad(self.teacher_head)\n", + "\n", + " self.criterion = DINOLoss(output_dim=2048, warmup_teacher_temp_epochs=5)\n", + "\n", + " def forward(self, x):\n", + " y = self.student_backbone(x).flatten(start_dim=1)\n", + " z = self.student_head(y)\n", + " return z\n", + "\n", + " def forward_teacher(self, x):\n", + " y = self.teacher_backbone(x).flatten(start_dim=1)\n", + " z = self.teacher_head(y)\n", + " return z\n", + "\n", + " def training_step(self, batch, batch_idx):\n", + " momentum = cosine_schedule(self.current_epoch, 10, 0.996, 1)\n", + " update_momentum(self.student_backbone, self.teacher_backbone, m=momentum)\n", + " update_momentum(self.student_head, self.teacher_head, m=momentum)\n", + " views = batch[0]\n", + " views = [view.to(self.device) for view in views]\n", + " global_views = views[:2]\n", + " teacher_out = [self.forward_teacher(view) for view in global_views]\n", + " student_out = [self.forward(view) for view in views]\n", + " loss = self.criterion(teacher_out, student_out, epoch=self.current_epoch)\n", + " return loss\n", + "\n", + " def on_after_backward(self):\n", + " self.student_head.cancel_last_layer_gradients(current_epoch=self.current_epoch)\n", + "\n", + " def configure_optimizers(self):\n", + " optim = torch.optim.Adam(self.parameters(), lr=0.001)\n", + " return optim" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "model = DINO()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "transform = DINOTransform()\n", + "# we ignore object detection annotations by setting target_transform to return 0\n", + 
"dataset = torchvision.datasets.VOCDetection(\n", + " \"datasets/pascal_voc\",\n", + " download=True,\n", + " transform=transform,\n", + " target_transform=lambda t: 0,\n", + ")\n", + "# or create a dataset from a folder containing images or videos:\n", + "# dataset = LightlyDataset(\"path/to/folder\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "dataloader = torch.utils.data.DataLoader(\n", + " dataset,\n", + " batch_size=64,\n", + " shuffle=True,\n", + " drop_last=True,\n", + " num_workers=8,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [], + "source": [ + "accelerator = \"gpu\" if torch.cuda.is_available() else \"cpu\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "11", + "metadata": {}, + "outputs": [], + "source": [ + "trainer = pl.Trainer(max_epochs=10, devices=1, accelerator=accelerator)\n", + "trainer.fit(model=model, train_dataloaders=dataloader)" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": "-all" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/notebooks/pytorch_lightning/fastsiam.ipynb b/examples/notebooks/pytorch_lightning/fastsiam.ipynb new file mode 100644 index 000000000..d539d1beb --- /dev/null +++ b/examples/notebooks/pytorch_lightning/fastsiam.ipynb @@ -0,0 +1,171 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "This example requires the following dependencies to be installed:\n", + "pip install lightly" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install lightly" + ] + }, + { + "cell_type": "markdown", + "id": "2", + "metadata": {}, + "source": [ + "Note: The model and training settings do not follow the reference settings\n", + "from the paper. The settings are chosen such that the example can easily be\n", + "run on a small dataset with a single GPU." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "import pytorch_lightning as pl\n", + "import torch\n", + "import torchvision\n", + "from torch import nn" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "from lightly.loss import NegativeCosineSimilarity\n", + "from lightly.models.modules import SimSiamPredictionHead, SimSiamProjectionHead\n", + "from lightly.transforms import FastSiamTransform" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "class FastSiam(pl.LightningModule):\n", + " def __init__(self):\n", + " super().__init__()\n", + " resnet = torchvision.models.resnet18()\n", + " self.backbone = nn.Sequential(*list(resnet.children())[:-1])\n", + " self.projection_head = SimSiamProjectionHead(512, 512, 128)\n", + " self.prediction_head = SimSiamPredictionHead(128, 64, 128)\n", + " self.criterion = NegativeCosineSimilarity()\n", + "\n", + " def forward(self, x):\n", + " f = self.backbone(x).flatten(start_dim=1)\n", + " z = self.projection_head(f)\n", + " p = self.prediction_head(z)\n", + " z = z.detach()\n", + " return z, p\n", + "\n", + " def training_step(self, batch, batch_idx):\n", + " views = batch[0]\n", + " features = [self.forward(view) for view in views]\n", + " zs = torch.stack([z for z, _ in features])\n", + " ps = torch.stack([p for _, p in features])\n", + "\n", + " loss = 0.0\n", + " for i in range(len(views)):\n", + " mask = torch.arange(len(views), device=self.device) != i\n", + " loss += self.criterion(ps[i], torch.mean(zs[mask], dim=0)) / len(views)\n", + "\n", + " self.log(\"train_loss_ssl\", loss)\n", + " return loss\n", + "\n", + " def configure_optimizers(self):\n", + " optim = torch.optim.SGD(self.parameters(), lr=0.06)\n", + " return optim" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "model = FastSiam()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "transform = FastSiamTransform(input_size=32)\n", + "dataset = torchvision.datasets.CIFAR10(\n", + " \"datasets/cifar10\", download=True, transform=transform\n", + ")\n", + "# or create a dataset from a folder containing images or videos:\n", + "# dataset = LightlyDataset(\"path/to/folder\", transform=transform)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "dataloader = torch.utils.data.DataLoader(\n", + " dataset,\n", + " batch_size=256,\n", + " shuffle=True,\n", + " drop_last=True,\n", + " num_workers=8,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "accelerator = \"gpu\" if torch.cuda.is_available() else \"cpu\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [], + "source": [ + "trainer = pl.Trainer(max_epochs=10, devices=1, accelerator=accelerator)\n", + "trainer.fit(model=model, train_dataloaders=dataloader)" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": "-all" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/notebooks/pytorch_lightning/ijepa.ipynb 
b/examples/notebooks/pytorch_lightning/ijepa.ipynb new file mode 100644 index 000000000..613c4b28e --- /dev/null +++ b/examples/notebooks/pytorch_lightning/ijepa.ipynb @@ -0,0 +1,21 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "TODO, for now please refer to our pure pytorch example" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": "-all" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/notebooks/pytorch_lightning/mae.ipynb b/examples/notebooks/pytorch_lightning/mae.ipynb new file mode 100644 index 000000000..ac444aa50 --- /dev/null +++ b/examples/notebooks/pytorch_lightning/mae.ipynb @@ -0,0 +1,204 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "This example requires the following dependencies to be installed:\n", + "pip install \"lightly[timm]\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install \"lightly[timm]\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2", + "metadata": {}, + "outputs": [], + "source": [ + "# Note: The model and training settings do not follow the reference settings\n", + "# from the paper. The settings are chosen such that the example can easily be\n", + "# run on a small dataset with a single GPU.\n", + "import pytorch_lightning as pl\n", + "import torch\n", + "import torchvision\n", + "from timm.models.vision_transformer import vit_base_patch32_224\n", + "from torch import nn" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "from lightly.models import utils\n", + "from lightly.models.modules import MAEDecoderTIMM, MaskedVisionTransformerTIMM\n", + "from lightly.transforms import MAETransform" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "class MAE(pl.LightningModule):\n", + " def __init__(self):\n", + " super().__init__()\n", + "\n", + " decoder_dim = 512\n", + " vit = vit_base_patch32_224()\n", + " self.mask_ratio = 0.75\n", + " self.patch_size = vit.patch_embed.patch_size[0]\n", + " self.backbone = MaskedVisionTransformerTIMM(vit=vit)\n", + " self.sequence_length = self.backbone.sequence_length\n", + " self.decoder = MAEDecoderTIMM(\n", + " num_patches=vit.patch_embed.num_patches,\n", + " patch_size=self.patch_size,\n", + " embed_dim=vit.embed_dim,\n", + " decoder_embed_dim=decoder_dim,\n", + " decoder_depth=1,\n", + " decoder_num_heads=16,\n", + " mlp_ratio=4.0,\n", + " proj_drop_rate=0.0,\n", + " attn_drop_rate=0.0,\n", + " )\n", + " self.criterion = nn.MSELoss()\n", + "\n", + " def forward_encoder(self, images, idx_keep=None):\n", + " return self.backbone.encode(images=images, idx_keep=idx_keep)\n", + "\n", + " def forward_decoder(self, x_encoded, idx_keep, idx_mask):\n", + " # build decoder input\n", + " batch_size = x_encoded.shape[0]\n", + " x_decode = self.decoder.embed(x_encoded)\n", + " x_masked = utils.repeat_token(\n", + " self.decoder.mask_token, (batch_size, self.sequence_length)\n", + " )\n", + " x_masked = utils.set_at_index(x_masked, idx_keep, x_decode.type_as(x_masked))\n", + "\n", + " # decoder forward pass\n", + " x_decoded = self.decoder.decode(x_masked)\n", + "\n", + " # predict pixel values for masked tokens\n", + " x_pred = utils.get_at_index(x_decoded, 
idx_mask)\n", + " x_pred = self.decoder.predict(x_pred)\n", + " return x_pred\n", + "\n", + " def training_step(self, batch, batch_idx):\n", + " views = batch[0]\n", + " images = views[0] # views contains only a single view\n", + " batch_size = images.shape[0]\n", + " idx_keep, idx_mask = utils.random_token_mask(\n", + " size=(batch_size, self.sequence_length),\n", + " mask_ratio=self.mask_ratio,\n", + " device=images.device,\n", + " )\n", + " x_encoded = self.forward_encoder(images=images, idx_keep=idx_keep)\n", + " x_pred = self.forward_decoder(\n", + " x_encoded=x_encoded, idx_keep=idx_keep, idx_mask=idx_mask\n", + " )\n", + "\n", + " # get image patches for masked tokens\n", + " patches = utils.patchify(images, self.patch_size)\n", + " # must adjust idx_mask for missing class token\n", + " target = utils.get_at_index(patches, idx_mask - 1)\n", + "\n", + " loss = self.criterion(x_pred, target)\n", + " return loss\n", + "\n", + " def configure_optimizers(self):\n", + " optim = torch.optim.AdamW(self.parameters(), lr=1.5e-4)\n", + " return optim" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "model = MAE()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "transform = MAETransform()\n", + "# we ignore object detection annotations by setting target_transform to return 0\n", + "dataset = torchvision.datasets.VOCDetection(\n", + " \"datasets/pascal_voc\",\n", + " download=True,\n", + " transform=transform,\n", + " target_transform=lambda t: 0,\n", + ")\n", + "# or create a dataset from a folder containing images or videos:\n", + "# dataset = LightlyDataset(\"path/to/folder\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "dataloader = torch.utils.data.DataLoader(\n", + " dataset,\n", + " batch_size=256,\n", + " shuffle=True,\n", + " drop_last=True,\n", + " num_workers=8,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "accelerator = \"gpu\" if torch.cuda.is_available() else \"cpu\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "trainer = pl.Trainer(max_epochs=10, devices=1, accelerator=accelerator)\n", + "trainer.fit(model=model, train_dataloaders=dataloader)" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": "-all" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/notebooks/pytorch_lightning/mmcr.ipynb b/examples/notebooks/pytorch_lightning/mmcr.ipynb new file mode 100644 index 000000000..6dc4ddd90 --- /dev/null +++ b/examples/notebooks/pytorch_lightning/mmcr.ipynb @@ -0,0 +1,193 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "This example requires the following dependencies to be installed:\n", + "pip install lightly" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install lightly" + ] + }, + { + "cell_type": "markdown", + "id": "2", + "metadata": {}, + "source": [ + "Note: The model and training settings do not follow the reference settings\n", + "from the paper. 
The settings are chosen such that the example can easily be\n", + "run on a small dataset with a single GPU." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "import copy" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "import pytorch_lightning as pl\n", + "import torch\n", + "import torchvision\n", + "from torch import nn" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "from lightly.loss import MMCRLoss\n", + "from lightly.models.modules import MMCRProjectionHead\n", + "from lightly.models.utils import deactivate_requires_grad, update_momentum\n", + "from lightly.transforms.mmcr_transform import MMCRTransform\n", + "from lightly.utils.scheduler import cosine_schedule" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "class MMCR(pl.LightningModule):\n", + "    def __init__(self):\n", + "        super().__init__()\n", + "        resnet = torchvision.models.resnet18()\n", + "        self.backbone = nn.Sequential(*list(resnet.children())[:-1])\n", + "        self.projection_head = MMCRProjectionHead(512, 512, 128)\n", + "\n", + "        self.backbone_momentum = copy.deepcopy(self.backbone)\n", + "        self.projection_head_momentum = copy.deepcopy(self.projection_head)\n", + "\n", + "        deactivate_requires_grad(self.backbone_momentum)\n", + "        deactivate_requires_grad(self.projection_head_momentum)\n", + "\n", + "        self.criterion = MMCRLoss()\n", + "\n", + "    def forward(self, x):\n", + "        y = self.backbone(x).flatten(start_dim=1)\n", + "        z = self.projection_head(y)\n", + "        return z\n", + "\n", + "    def forward_momentum(self, x):\n", + "        y = self.backbone_momentum(x).flatten(start_dim=1)\n", + "        z = self.projection_head_momentum(y)\n", + "        z = z.detach()\n", + "        return z\n", + "\n", + "    def training_step(self, batch, batch_idx):\n", + "        momentum = cosine_schedule(self.current_epoch, 10, 0.996, 1)\n", + "        update_momentum(self.backbone, self.backbone_momentum, m=momentum)\n", + "        update_momentum(self.projection_head, self.projection_head_momentum, m=momentum)\n", + "        # encode each of the k views with the online and momentum networks\n", + "        z_o = [self.forward(x) for x in batch[0]]\n", + "        z_m = [self.forward_momentum(x) for x in batch[0]]\n", + "\n", + "        # Switch dimensions to (batch_size, k, embedding_size)\n", + "        z_o = torch.stack(z_o, dim=1)\n", + "        z_m = torch.stack(z_m, dim=1)\n", + "\n", + "        loss = self.criterion(z_o, z_m)\n", + "        return loss\n", + "\n", + "    def configure_optimizers(self):\n", + "        return torch.optim.SGD(self.parameters(), lr=0.06)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "model = MMCR()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "# We disable resizing and gaussian blur for cifar10.\n", + "transform = MMCRTransform(k=8, input_size=32, gaussian_blur=0.0)\n", + "dataset = torchvision.datasets.CIFAR10(\n", + "    \"datasets/cifar10\", download=True, transform=transform\n", + ")\n", + "# or create a dataset from a folder containing images or videos:\n", + "# dataset = LightlyDataset(\"path/to/folder\", transform=transform)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "dataloader = torch.utils.data.DataLoader(\n", + "    dataset,\n", + "
batch_size=256,\n", + " shuffle=True,\n", + " drop_last=True,\n", + " num_workers=8,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [], + "source": [ + "accelerator = \"gpu\" if torch.cuda.is_available() else \"cpu\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "11", + "metadata": {}, + "outputs": [], + "source": [ + "trainer = pl.Trainer(max_epochs=10, devices=1, accelerator=accelerator)\n", + "trainer.fit(model=model, train_dataloaders=dataloader)" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": "-all" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/notebooks/pytorch_lightning/moco.ipynb b/examples/notebooks/pytorch_lightning/moco.ipynb new file mode 100644 index 000000000..e2535e04f --- /dev/null +++ b/examples/notebooks/pytorch_lightning/moco.ipynb @@ -0,0 +1,188 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "This example requires the following dependencies to be installed:\n", + "pip install lightly" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install lightly" + ] + }, + { + "cell_type": "markdown", + "id": "2", + "metadata": {}, + "source": [ + "Note: The model and training settings do not follow the reference settings\n", + "from the paper. The settings are chosen such that the example can easily be\n", + "run on a small dataset with a single GPU." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "import copy" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "import pytorch_lightning as pl\n", + "import torch\n", + "import torchvision\n", + "from torch import nn" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "from lightly.loss import NTXentLoss\n", + "from lightly.models.modules import MoCoProjectionHead\n", + "from lightly.models.utils import deactivate_requires_grad, update_momentum\n", + "from lightly.transforms.moco_transform import MoCoV2Transform\n", + "from lightly.utils.scheduler import cosine_schedule" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "class MoCo(pl.LightningModule):\n", + " def __init__(self):\n", + " super().__init__()\n", + " resnet = torchvision.models.resnet18()\n", + " self.backbone = nn.Sequential(*list(resnet.children())[:-1])\n", + " self.projection_head = MoCoProjectionHead(512, 512, 128)\n", + "\n", + " self.backbone_momentum = copy.deepcopy(self.backbone)\n", + " self.projection_head_momentum = copy.deepcopy(self.projection_head)\n", + "\n", + " deactivate_requires_grad(self.backbone_momentum)\n", + " deactivate_requires_grad(self.projection_head_momentum)\n", + "\n", + " self.criterion = NTXentLoss(memory_bank_size=(4096, 128))\n", + "\n", + " def forward(self, x):\n", + " query = self.backbone(x).flatten(start_dim=1)\n", + " query = self.projection_head(query)\n", + " return query\n", + "\n", + " def forward_momentum(self, x):\n", + " key = self.backbone_momentum(x).flatten(start_dim=1)\n", + " key = self.projection_head_momentum(key).detach()\n", + " return key\n", + "\n", + " def 
training_step(self, batch, batch_idx):\n", + " momentum = cosine_schedule(self.current_epoch, 10, 0.996, 1)\n", + " update_momentum(self.backbone, self.backbone_momentum, m=momentum)\n", + " update_momentum(self.projection_head, self.projection_head_momentum, m=momentum)\n", + " x_query, x_key = batch[0]\n", + " query = self.forward(x_query)\n", + " key = self.forward_momentum(x_key)\n", + " loss = self.criterion(query, key)\n", + " return loss\n", + "\n", + " def configure_optimizers(self):\n", + " optim = torch.optim.SGD(self.parameters(), lr=0.06)\n", + " return optim" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "model = MoCo()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "transform = MoCoV2Transform(input_size=32)\n", + "dataset = torchvision.datasets.CIFAR10(\n", + " \"datasets/cifar10\", download=True, transform=transform\n", + ")\n", + "# or create a dataset from a folder containing images or videos:\n", + "# dataset = LightlyDataset(\"path/to/folder\", transform=transform)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "dataloader = torch.utils.data.DataLoader(\n", + " dataset,\n", + " batch_size=256,\n", + " shuffle=True,\n", + " drop_last=True,\n", + " num_workers=8,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [], + "source": [ + "accelerator = \"gpu\" if torch.cuda.is_available() else \"cpu\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "11", + "metadata": {}, + "outputs": [], + "source": [ + "trainer = pl.Trainer(max_epochs=10, devices=1, accelerator=accelerator)\n", + "trainer.fit(model=model, train_dataloaders=dataloader)" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": "-all" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/notebooks/pytorch_lightning/msn.ipynb b/examples/notebooks/pytorch_lightning/msn.ipynb new file mode 100644 index 000000000..be24c7a44 --- /dev/null +++ b/examples/notebooks/pytorch_lightning/msn.ipynb @@ -0,0 +1,215 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "This example requires the following dependencies to be installed:\n", + "pip install lightly" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install lightly" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2", + "metadata": {}, + "outputs": [], + "source": [ + "# Note: The model and training settings do not follow the reference settings\n", + "# from the paper. 
The settings are chosen such that the example can easily be\n", + "# run on a small dataset with a single GPU.\n", + "import copy" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "import pytorch_lightning as pl\n", + "import torch\n", + "import torchvision\n", + "from torch import nn" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "from lightly.loss import MSNLoss\n", + "from lightly.models import utils\n", + "from lightly.models.modules import MaskedVisionTransformerTorchvision\n", + "from lightly.models.modules.heads import MSNProjectionHead\n", + "from lightly.transforms.msn_transform import MSNTransform" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "class MSN(pl.LightningModule):\n", + "    def __init__(self):\n", + "        super().__init__()\n", + "\n", + "        # ViT small configuration (ViT-S/16)\n", + "        self.mask_ratio = 0.15\n", + "        vit = torchvision.models.VisionTransformer(\n", + "            image_size=224,\n", + "            patch_size=16,\n", + "            num_layers=12,\n", + "            num_heads=6,\n", + "            hidden_dim=384,\n", + "            mlp_dim=384 * 4,\n", + "        )\n", + "        self.backbone = MaskedVisionTransformerTorchvision(vit=vit)\n", + "        # or use a torchvision ViT backbone:\n", + "        # vit = torchvision.models.vit_b_32(pretrained=False)\n", + "        # self.backbone = MAEBackbone.from_vit(vit)\n", + "        self.projection_head = MSNProjectionHead(384)\n", + "\n", + "        self.anchor_backbone = copy.deepcopy(self.backbone)\n", + "        self.anchor_projection_head = copy.deepcopy(self.projection_head)\n", + "\n", + "        utils.deactivate_requires_grad(self.backbone)\n", + "        utils.deactivate_requires_grad(self.projection_head)\n", + "\n", + "        self.prototypes = nn.Linear(256, 1024, bias=False).weight\n", + "        self.criterion = MSNLoss()\n", + "\n", + "    def training_step(self, batch, batch_idx):\n", + "        utils.update_momentum(self.anchor_backbone, self.backbone, 0.996)\n", + "        utils.update_momentum(self.anchor_projection_head, self.projection_head, 0.996)\n", + "\n", + "        views = batch[0]\n", + "        views = [view.to(self.device, non_blocking=True) for view in views]\n", + "        targets = views[0]\n", + "        anchors = views[1]\n", + "        anchors_focal = torch.concat(views[2:], dim=0)\n", + "\n", + "        targets_out = self.backbone(images=targets)\n", + "        targets_out = self.projection_head(targets_out)\n", + "        anchors_out = self.encode_masked(anchors)\n", + "        anchors_focal_out = self.encode_masked(anchors_focal)\n", + "        anchors_out = torch.cat([anchors_out, anchors_focal_out], dim=0)\n", + "\n", + "        loss = self.criterion(anchors_out, targets_out, self.prototypes.data)\n", + "        return loss\n", + "\n", + "    def encode_masked(self, anchors):\n", + "        batch_size, _, _, width = anchors.shape\n", + "        seq_length = (width // self.anchor_backbone.vit.patch_size) ** 2\n", + "        idx_keep, _ = utils.random_token_mask(\n", + "            size=(batch_size, seq_length),\n", + "            mask_ratio=self.mask_ratio,\n", + "            device=self.device,\n", + "        )\n", + "        out = self.anchor_backbone(images=anchors, idx_keep=idx_keep)\n", + "        return self.anchor_projection_head(out)\n", + "\n", + "    def configure_optimizers(self):\n", + "        params = [\n", + "            *list(self.anchor_backbone.parameters()),\n", + "            *list(self.anchor_projection_head.parameters()),\n", + "            self.prototypes,\n", + "        ]\n", + "        optim = torch.optim.AdamW(params, lr=1.5e-4)\n", + "        return optim"
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "model = MSN()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "transform = MSNTransform()\n", + "# we ignore object detection annotations by setting target_transform to return 0\n", + "dataset = torchvision.datasets.VOCDetection(\n", + " \"datasets/pascal_voc\",\n", + " download=True,\n", + " transform=transform,\n", + " target_transform=lambda t: 0,\n", + ")\n", + "# or create a dataset from a folder containing images or videos:\n", + "# dataset = LightlyDataset(\"path/to/folder\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "dataloader = torch.utils.data.DataLoader(\n", + " dataset,\n", + " batch_size=64,\n", + " shuffle=True,\n", + " drop_last=True,\n", + " num_workers=8,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "accelerator = \"gpu\" if torch.cuda.is_available() else \"cpu\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [], + "source": [ + "trainer = pl.Trainer(max_epochs=10, devices=1, accelerator=accelerator)\n", + "trainer.fit(model=model, train_dataloaders=dataloader)" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": "-all" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/notebooks/pytorch_lightning/nnclr.ipynb b/examples/notebooks/pytorch_lightning/nnclr.ipynb new file mode 100644 index 000000000..280f96e7b --- /dev/null +++ b/examples/notebooks/pytorch_lightning/nnclr.ipynb @@ -0,0 +1,172 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "This example requires the following dependencies to be installed:\n", + "pip install lightly" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install lightly" + ] + }, + { + "cell_type": "markdown", + "id": "2", + "metadata": {}, + "source": [ + "Note: The model and training settings do not follow the reference settings\n", + "from the paper. The settings are chosen such that the example can easily be\n", + "run on a small dataset with a single GPU." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "import pytorch_lightning as pl\n", + "import torch\n", + "import torchvision\n", + "from torch import nn" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "from lightly.loss import NTXentLoss\n", + "from lightly.models.modules import (\n", + " NNCLRPredictionHead,\n", + " NNCLRProjectionHead,\n", + " NNMemoryBankModule,\n", + ")\n", + "from lightly.transforms.simclr_transform import SimCLRTransform" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "class NNCLR(pl.LightningModule):\n", + " def __init__(self):\n", + " super().__init__()\n", + " resnet = torchvision.models.resnet18()\n", + " self.backbone = nn.Sequential(*list(resnet.children())[:-1])\n", + " self.projection_head = NNCLRProjectionHead(512, 512, 128)\n", + " self.prediction_head = NNCLRPredictionHead(128, 512, 128)\n", + " self.memory_bank = NNMemoryBankModule(size=(4096, 128))\n", + "\n", + " self.criterion = NTXentLoss()\n", + "\n", + " def forward(self, x):\n", + " y = self.backbone(x).flatten(start_dim=1)\n", + " z = self.projection_head(y)\n", + " p = self.prediction_head(z)\n", + " z = z.detach()\n", + " return z, p\n", + "\n", + " def training_step(self, batch, batch_idx):\n", + " (x0, x1) = batch[0]\n", + " z0, p0 = self.forward(x0)\n", + " z1, p1 = self.forward(x1)\n", + " z0 = self.memory_bank(z0, update=False)\n", + " z1 = self.memory_bank(z1, update=True)\n", + " loss = 0.5 * (self.criterion(z0, p1) + self.criterion(z1, p0))\n", + " return loss\n", + "\n", + " def configure_optimizers(self):\n", + " optim = torch.optim.SGD(self.parameters(), lr=0.06)\n", + " return optim" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "model = NNCLR()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "transform = SimCLRTransform(input_size=32)\n", + "dataset = torchvision.datasets.CIFAR10(\n", + " \"datasets/cifar10\", download=True, transform=transform\n", + ")\n", + "# or create a dataset from a folder containing images or videos:\n", + "# dataset = LightlyDataset(\"path/to/folder\", transform=transform)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "dataloader = torch.utils.data.DataLoader(\n", + " dataset,\n", + " batch_size=256,\n", + " shuffle=True,\n", + " drop_last=True,\n", + " num_workers=8,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "accelerator = \"gpu\" if torch.cuda.is_available() else \"cpu\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [], + "source": [ + "trainer = pl.Trainer(max_epochs=10, devices=1, accelerator=accelerator)\n", + "trainer.fit(model=model, train_dataloaders=dataloader)" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": "-all" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/notebooks/pytorch_lightning/pmsn.ipynb b/examples/notebooks/pytorch_lightning/pmsn.ipynb new file mode 100644 index 000000000..619c7b382 --- 
/dev/null +++ b/examples/notebooks/pytorch_lightning/pmsn.ipynb @@ -0,0 +1,214 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "This example requires the following dependencies to be installed:\n", + "pip install lightly" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install lightly" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2", + "metadata": {}, + "outputs": [], + "source": [ + "# Note: The model and training settings do not follow the reference settings\n", + "# from the paper. The settings are chosen such that the example can easily be\n", + "# run on a small dataset with a single GPU.\n", + "import copy" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "import pytorch_lightning as pl\n", + "import torch\n", + "import torchvision\n", + "from torch import nn" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "from lightly.loss import PMSNLoss\n", + "from lightly.models import utils\n", + "from lightly.models.modules import MaskedVisionTransformerTorchvision\n", + "from lightly.models.modules.heads import MSNProjectionHead\n", + "from lightly.transforms import MSNTransform" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "class PMSN(pl.LightningModule):\n", + " def __init__(self):\n", + " super().__init__()\n", + "\n", + " # ViT small configuration (ViT-S/16)\n", + " self.mask_ratio = 0.15\n", + " vit = torchvision.models.VisionTransformer(\n", + " image_size=224,\n", + " patch_size=16,\n", + " num_layers=12,\n", + " num_heads=6,\n", + " hidden_dim=384,\n", + " mlp_dim=384 * 4,\n", + " )\n", + " self.backbone = MaskedVisionTransformerTorchvision(vit=vit)\n", + " # or use a torchvision ViT backbone:\n", + " # vit = torchvision.models.vit_b_32(pretrained=False)\n", + " # self.backbone = MAEBackbone.from_vit(vit)\n", + " self.projection_head = MSNProjectionHead(384)\n", + "\n", + " self.anchor_backbone = copy.deepcopy(self.backbone)\n", + " self.anchor_projection_head = copy.deepcopy(self.projection_head)\n", + "\n", + " utils.deactivate_requires_grad(self.backbone)\n", + " utils.deactivate_requires_grad(self.projection_head)\n", + "\n", + " self.prototypes = nn.Linear(256, 1024, bias=False).weight\n", + " self.criterion = PMSNLoss()\n", + "\n", + " def training_step(self, batch, batch_idx):\n", + " utils.update_momentum(self.anchor_backbone, self.backbone, 0.996)\n", + " utils.update_momentum(self.anchor_projection_head, self.projection_head, 0.996)\n", + "\n", + " views = batch[0]\n", + " views = [view.to(self.device, non_blocking=True) for view in views]\n", + " targets = views[0]\n", + " anchors = views[1]\n", + " anchors_focal = torch.concat(views[2:], dim=0)\n", + "\n", + " targets_out = self.backbone(images=targets)\n", + " targets_out = self.projection_head(targets_out)\n", + " anchors_out = self.encode_masked(anchors)\n", + " anchors_focal_out = self.encode_masked(anchors_focal)\n", + " anchors_out = torch.cat([anchors_out, anchors_focal_out], dim=0)\n", + "\n", + " loss = self.criterion(anchors_out, targets_out, self.prototypes.data)\n", + " return loss\n", + "\n", + " def encode_masked(self, anchors):\n", + " batch_size, _, _, width = anchors.shape\n", + " seq_length = (width // 
self.anchor_backbone.vit.patch_size) ** 2\n", + " idx_keep, _ = utils.random_token_mask(\n", + " size=(batch_size, seq_length),\n", + " mask_ratio=self.mask_ratio,\n", + " device=self.device,\n", + " )\n", + " out = self.anchor_backbone(images=anchors, idx_keep=idx_keep)\n", + " return self.anchor_projection_head(out)\n", + "\n", + " def configure_optimizers(self):\n", + " params = [\n", + " *list(self.anchor_backbone.parameters()),\n", + " *list(self.anchor_projection_head.parameters()),\n", + " self.prototypes,\n", + " ]\n", + " optim = torch.optim.AdamW(params, lr=1.5e-4)\n", + " return optim" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "model = PMSN()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "transform = MSNTransform()\n", + "# we ignore object detection annotations by setting target_transform to return 0\n", + "dataset = torchvision.datasets.VOCDetection(\n", + " \"datasets/pascal_voc\",\n", + " download=True,\n", + " transform=transform,\n", + " target_transform=lambda t: 0,\n", + ")\n", + "# or create a dataset from a folder containing images or videos:\n", + "# dataset = LightlyDataset(\"path/to/folder\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "dataloader = torch.utils.data.DataLoader(\n", + " dataset,\n", + " batch_size=64,\n", + " shuffle=True,\n", + " drop_last=True,\n", + " num_workers=8,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "accelerator = \"gpu\" if torch.cuda.is_available() else \"cpu\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [], + "source": [ + "trainer = pl.Trainer(max_epochs=10, devices=1, accelerator=accelerator)\n", + "trainer.fit(model=model, train_dataloaders=dataloader)" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": "-all" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/notebooks/pytorch_lightning/simclr.ipynb b/examples/notebooks/pytorch_lightning/simclr.ipynb new file mode 100644 index 000000000..762b16aad --- /dev/null +++ b/examples/notebooks/pytorch_lightning/simclr.ipynb @@ -0,0 +1,161 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "This example requires the following dependencies to be installed:\n", + "pip install lightly" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install lightly" + ] + }, + { + "cell_type": "markdown", + "id": "2", + "metadata": {}, + "source": [ + "Note: The model and training settings do not follow the reference settings\n", + "from the paper. The settings are chosen such that the example can easily be\n", + "run on a small dataset with a single GPU." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "import pytorch_lightning as pl\n", + "import torch\n", + "import torchvision\n", + "from torch import nn" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "from lightly.loss import NTXentLoss\n", + "from lightly.models.modules import SimCLRProjectionHead\n", + "from lightly.transforms.simclr_transform import SimCLRTransform" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "class SimCLR(pl.LightningModule):\n", + " def __init__(self):\n", + " super().__init__()\n", + " resnet = torchvision.models.resnet18()\n", + " self.backbone = nn.Sequential(*list(resnet.children())[:-1])\n", + " self.projection_head = SimCLRProjectionHead(512, 2048, 2048)\n", + " self.criterion = NTXentLoss()\n", + "\n", + " def forward(self, x):\n", + " x = self.backbone(x).flatten(start_dim=1)\n", + " z = self.projection_head(x)\n", + " return z\n", + "\n", + " def training_step(self, batch, batch_index):\n", + " (x0, x1) = batch[0]\n", + " z0 = self.forward(x0)\n", + " z1 = self.forward(x1)\n", + " loss = self.criterion(z0, z1)\n", + " return loss\n", + "\n", + " def configure_optimizers(self):\n", + " optim = torch.optim.SGD(self.parameters(), lr=0.06)\n", + " return optim" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "model = SimCLR()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "transform = SimCLRTransform(input_size=32)\n", + "dataset = torchvision.datasets.CIFAR10(\n", + " \"datasets/cifar10\", download=True, transform=transform\n", + ")\n", + "# or create a dataset from a folder containing images or videos:\n", + "# dataset = LightlyDataset(\"path/to/folder\", transform=transform)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "dataloader = torch.utils.data.DataLoader(\n", + " dataset,\n", + " batch_size=256,\n", + " shuffle=True,\n", + " drop_last=True,\n", + " num_workers=8,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "accelerator = \"gpu\" if torch.cuda.is_available() else \"cpu\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [], + "source": [ + "trainer = pl.Trainer(max_epochs=10, devices=1, accelerator=accelerator)\n", + "trainer.fit(model=model, train_dataloaders=dataloader)" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": "-all" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/notebooks/pytorch_lightning/simmim.ipynb b/examples/notebooks/pytorch_lightning/simmim.ipynb new file mode 100644 index 000000000..9b23882fe --- /dev/null +++ b/examples/notebooks/pytorch_lightning/simmim.ipynb @@ -0,0 +1,198 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "This example requires the following dependencies to be installed:\n", + "pip install lightly" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "!pip 
install lightly" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2", + "metadata": {}, + "outputs": [], + "source": [ + "import pytorch_lightning as pl\n", + "import torch\n", + "import torchvision\n", + "from torch import nn" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "from lightly.models import utils\n", + "from lightly.models.modules import MaskedVisionTransformerTorchvision\n", + "from lightly.transforms.mae_transform import MAETransform # Same transform as MAE" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "class SimMIM(pl.LightningModule):\n", + " def __init__(self):\n", + " super().__init__()\n", + "\n", + " vit = torchvision.models.vit_b_32(pretrained=False)\n", + " decoder_dim = vit.hidden_dim\n", + " self.mask_ratio = 0.75\n", + " self.patch_size = vit.patch_size\n", + " self.sequence_length = vit.seq_length\n", + " self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_dim))\n", + "\n", + " self.backbone = MaskedVisionTransformerTorchvision(vit=vit)\n", + "\n", + " # the decoder is a simple linear layer\n", + " self.decoder = nn.Linear(vit.hidden_dim, vit.patch_size**2 * 3)\n", + "\n", + " # L1 loss as paper suggestion\n", + " self.criterion = nn.L1Loss()\n", + "\n", + " def forward_encoder(self, images, batch_size, idx_mask):\n", + " # pass all the tokens to the encoder, both masked and non masked ones\n", + " return self.backbone.encode(images=images, idx_mask=idx_mask)\n", + "\n", + " def forward_decoder(self, x_encoded):\n", + " return self.decoder(x_encoded)\n", + "\n", + " def training_step(self, batch, batch_idx):\n", + " views = batch[0]\n", + " images = views[0] # views contains only a single view\n", + " batch_size = images.shape[0]\n", + " idx_keep, idx_mask = utils.random_token_mask(\n", + " size=(batch_size, self.sequence_length),\n", + " mask_ratio=self.mask_ratio,\n", + " device=images.device,\n", + " )\n", + "\n", + " # Encoding...\n", + " x_encoded = self.forward_encoder(images, batch_size, idx_mask)\n", + " x_encoded_masked = utils.get_at_index(x_encoded, idx_mask)\n", + "\n", + " # Decoding...\n", + " x_out = self.forward_decoder(x_encoded_masked)\n", + "\n", + " # get image patches for masked tokens\n", + " patches = utils.patchify(images, self.patch_size)\n", + "\n", + " # must adjust idx_mask for missing class token\n", + " target = utils.get_at_index(patches, idx_mask - 1)\n", + "\n", + " loss = self.criterion(x_out, target)\n", + " return loss\n", + "\n", + " def configure_optimizers(self):\n", + " optim = torch.optim.AdamW(self.parameters(), lr=1.5e-4)\n", + " return optim" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "model = SimMIM()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "model.to(device)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "transform = MAETransform()\n", + "# we ignore object detection annotations by setting target_transform to return 0\n", + "dataset = torchvision.datasets.VOCDetection(\n", + " \"datasets/pascal_voc\",\n", + " download=True,\n", + " transform=transform,\n", + " target_transform=lambda t: 0,\n", + ")\n", + "# or create a 
dataset from a folder containing images or videos:\n", + "# dataset = LightlyDataset(\"path/to/folder\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "dataloader = torch.utils.data.DataLoader(\n", + " dataset,\n", + " batch_size=8,\n", + " shuffle=True,\n", + " drop_last=True,\n", + " num_workers=8,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "accelerator = \"gpu\" if torch.cuda.is_available() else \"cpu\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [], + "source": [ + "trainer = pl.Trainer(max_epochs=10, devices=1, accelerator=accelerator)\n", + "trainer.fit(model=model, train_dataloaders=dataloader)" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": "-all" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/notebooks/pytorch_lightning/simsiam.ipynb b/examples/notebooks/pytorch_lightning/simsiam.ipynb new file mode 100644 index 000000000..6ede439da --- /dev/null +++ b/examples/notebooks/pytorch_lightning/simsiam.ipynb @@ -0,0 +1,164 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "This example requires the following dependencies to be installed:\n", + "pip install lightly" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install lightly" + ] + }, + { + "cell_type": "markdown", + "id": "2", + "metadata": {}, + "source": [ + "Note: The model and training settings do not follow the reference settings\n", + "from the paper. The settings are chosen such that the example can easily be\n", + "run on a small dataset with a single GPU." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "import pytorch_lightning as pl\n", + "import torch\n", + "import torchvision\n", + "from torch import nn" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "from lightly.loss import NegativeCosineSimilarity\n", + "from lightly.models.modules import SimSiamPredictionHead, SimSiamProjectionHead\n", + "from lightly.transforms import SimSiamTransform" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "class SimSiam(pl.LightningModule):\n", + " def __init__(self):\n", + " super().__init__()\n", + " resnet = torchvision.models.resnet18()\n", + " self.backbone = nn.Sequential(*list(resnet.children())[:-1])\n", + " self.projection_head = SimSiamProjectionHead(512, 512, 128)\n", + " self.prediction_head = SimSiamPredictionHead(128, 64, 128)\n", + " self.criterion = NegativeCosineSimilarity()\n", + "\n", + " def forward(self, x):\n", + " f = self.backbone(x).flatten(start_dim=1)\n", + " z = self.projection_head(f)\n", + " p = self.prediction_head(z)\n", + " z = z.detach()\n", + " return z, p\n", + "\n", + " def training_step(self, batch, batch_idx):\n", + " (x0, x1) = batch[0]\n", + " z0, p0 = self.forward(x0)\n", + " z1, p1 = self.forward(x1)\n", + " loss = 0.5 * (self.criterion(z0, p1) + self.criterion(z1, p0))\n", + " return loss\n", + "\n", + " def configure_optimizers(self):\n", + " optim = torch.optim.SGD(self.parameters(), lr=0.06)\n", + " return optim" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "model = SimSiam()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "transform = SimSiamTransform(input_size=32)\n", + "dataset = torchvision.datasets.CIFAR10(\n", + " \"datasets/cifar10\", download=True, transform=transform\n", + ")\n", + "# or create a dataset from a folder containing images or videos:\n", + "# dataset = LightlyDataset(\"path/to/folder\", transform=transform)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "dataloader = torch.utils.data.DataLoader(\n", + " dataset,\n", + " batch_size=256,\n", + " shuffle=True,\n", + " drop_last=True,\n", + " num_workers=8,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "accelerator = \"gpu\" if torch.cuda.is_available() else \"cpu\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [], + "source": [ + "trainer = pl.Trainer(max_epochs=10, devices=1, accelerator=accelerator)\n", + "trainer.fit(model=model, train_dataloaders=dataloader)" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": "-all" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/notebooks/pytorch_lightning/smog.ipynb b/examples/notebooks/pytorch_lightning/smog.ipynb new file mode 100644 index 000000000..57196e254 --- /dev/null +++ b/examples/notebooks/pytorch_lightning/smog.ipynb @@ -0,0 +1,254 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "This example requires the 
following dependencies to be installed:\n", + "pip install lightly" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install lightly" + ] + }, + { + "cell_type": "markdown", + "id": "2", + "metadata": {}, + "source": [ + "Note: The model and training settings do not follow the reference settings\n", + "from the paper. The settings are chosen such that the example can easily be\n", + "run on a small dataset with a single GPU." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "import copy" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "import pytorch_lightning as pl\n", + "import torch\n", + "import torchvision\n", + "from sklearn.cluster import KMeans\n", + "from torch import nn" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "from lightly import loss, models\n", + "from lightly.models import utils\n", + "from lightly.models.modules import heads\n", + "from lightly.transforms.smog_transform import SMoGTransform" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "class SMoGModel(pl.LightningModule):\n", + " def __init__(self):\n", + " super().__init__()\n", + " # create a ResNet backbone and remove the classification head\n", + " resnet = models.ResNetGenerator(\"resnet-18\")\n", + " self.backbone = nn.Sequential(\n", + " *list(resnet.children())[:-1], nn.AdaptiveAvgPool2d(1)\n", + " )\n", + "\n", + " # create a model based on ResNet\n", + " self.projection_head = heads.SMoGProjectionHead(512, 2048, 128)\n", + " self.prediction_head = heads.SMoGPredictionHead(128, 2048, 128)\n", + " self.backbone_momentum = copy.deepcopy(self.backbone)\n", + " self.projection_head_momentum = copy.deepcopy(self.projection_head)\n", + " utils.deactivate_requires_grad(self.backbone_momentum)\n", + " utils.deactivate_requires_grad(self.projection_head_momentum)\n", + "\n", + " # smog\n", + " self.n_groups = 300\n", + " memory_bank_size = 10000\n", + " self.memory_bank = loss.memory_bank.MemoryBankModule(size=memory_bank_size)\n", + " # create our loss\n", + " group_features = torch.nn.functional.normalize(\n", + " torch.rand(self.n_groups, 128), dim=1\n", + " ).to(self.device)\n", + " self.smog = heads.SMoGPrototypes(group_features=group_features, beta=0.99)\n", + " self.criterion = nn.CrossEntropyLoss()\n", + "\n", + " def _cluster_features(self, features: torch.Tensor) -> torch.Tensor:\n", + " features = features.cpu().numpy()\n", + " kmeans = KMeans(self.n_groups).fit(features)\n", + " clustered = torch.from_numpy(kmeans.cluster_centers_).float()\n", + " clustered = torch.nn.functional.normalize(clustered, dim=1)\n", + " return clustered\n", + "\n", + " def _reset_group_features(self):\n", + " # see https://arxiv.org/pdf/2207.06167.pdf Table 7b)\n", + " features = self.memory_bank.bank\n", + " group_features = self._cluster_features(features.t())\n", + " self.smog.set_group_features(group_features)\n", + "\n", + " def _reset_momentum_weights(self):\n", + " # see https://arxiv.org/pdf/2207.06167.pdf Table 7b)\n", + " self.backbone_momentum = copy.deepcopy(self.backbone)\n", + " self.projection_head_momentum = copy.deepcopy(self.projection_head)\n", + " utils.deactivate_requires_grad(self.backbone_momentum)\n", + " 
utils.deactivate_requires_grad(self.projection_head_momentum)\n", + "\n", + " def training_step(self, batch, batch_idx):\n", + " if self.global_step > 0 and self.global_step % 300 == 0:\n", + " # reset group features and weights every 300 iterations\n", + " self._reset_group_features()\n", + " self._reset_momentum_weights()\n", + " else:\n", + " # update momentum\n", + " utils.update_momentum(self.backbone, self.backbone_momentum, 0.99)\n", + " utils.update_momentum(\n", + " self.projection_head, self.projection_head_momentum, 0.99\n", + " )\n", + "\n", + " (x0, x1) = batch[0]\n", + "\n", + " if batch_idx % 2:\n", + " # swap batches every second iteration\n", + " x0, x1 = x1, x0\n", + "\n", + " x0_features = self.backbone(x0).flatten(start_dim=1)\n", + " x0_encoded = self.projection_head(x0_features)\n", + " x0_predicted = self.prediction_head(x0_encoded)\n", + " x1_features = self.backbone_momentum(x1).flatten(start_dim=1)\n", + " x1_encoded = self.projection_head_momentum(x1_features)\n", + "\n", + " # update group features and get group assignments\n", + " assignments = self.smog.assign_groups(x1_encoded)\n", + " group_features = self.smog.get_updated_group_features(x0_encoded)\n", + " logits = self.smog(x0_predicted, group_features, temperature=0.1)\n", + " self.smog.set_group_features(group_features)\n", + "\n", + " loss = self.criterion(logits, assignments)\n", + "\n", + " # use memory bank to periodically reset the group features with k-means\n", + " self.memory_bank(x0_encoded, update=True)\n", + "\n", + " return loss\n", + "\n", + " def configure_optimizers(self):\n", + " params = (\n", + " list(self.backbone.parameters())\n", + " + list(self.projection_head.parameters())\n", + " + list(self.prediction_head.parameters())\n", + " )\n", + " optim = torch.optim.SGD(\n", + " params,\n", + " lr=0.01,\n", + " momentum=0.9,\n", + " weight_decay=1e-6,\n", + " )\n", + " return optim" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "model = SMoGModel()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "transform = SMoGTransform(\n", + " crop_sizes=(32, 32),\n", + " crop_counts=(1, 1),\n", + " gaussian_blur_probs=(0.0, 0.0),\n", + " crop_min_scales=(0.2, 0.2),\n", + " crop_max_scales=(1.0, 1.0),\n", + ")\n", + "dataset = torchvision.datasets.CIFAR10(\n", + " \"datasets/cifar10\", download=True, transform=transform\n", + ")\n", + "# or create a dataset from a folder containing images or videos:\n", + "# dataset = LightlyDataset(\"path/to/folder\", transform=transform)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "dataloader = torch.utils.data.DataLoader(\n", + " dataset,\n", + " batch_size=256,\n", + " shuffle=True,\n", + " drop_last=True,\n", + " num_workers=8,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [], + "source": [ + "accelerator = \"gpu\" if torch.cuda.is_available() else \"cpu\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "11", + "metadata": {}, + "outputs": [], + "source": [ + "trainer = pl.Trainer(max_epochs=10, devices=1, accelerator=accelerator)\n", + "trainer.fit(model=model, train_dataloaders=dataloader)" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": 
"-all" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/notebooks/pytorch_lightning/swav.ipynb b/examples/notebooks/pytorch_lightning/swav.ipynb new file mode 100644 index 000000000..21b13a73d --- /dev/null +++ b/examples/notebooks/pytorch_lightning/swav.ipynb @@ -0,0 +1,170 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "This example requires the following dependencies to be installed:\n", + "pip install lightly" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install lightly" + ] + }, + { + "cell_type": "markdown", + "id": "2", + "metadata": {}, + "source": [ + "Note: The model and training settings do not follow the reference settings\n", + "from the paper. The settings are chosen such that the example can easily be\n", + "run on a small dataset with a single GPU." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "import pytorch_lightning as pl\n", + "import torch\n", + "import torchvision\n", + "from torch import nn" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "from lightly.loss import SwaVLoss\n", + "from lightly.models.modules import SwaVProjectionHead, SwaVPrototypes\n", + "from lightly.transforms.swav_transform import SwaVTransform" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "class SwaV(pl.LightningModule):\n", + " def __init__(self):\n", + " super().__init__()\n", + " resnet = torchvision.models.resnet18()\n", + " self.backbone = nn.Sequential(*list(resnet.children())[:-1])\n", + " self.projection_head = SwaVProjectionHead(512, 512, 128)\n", + " self.prototypes = SwaVPrototypes(128, n_prototypes=512)\n", + " self.criterion = SwaVLoss()\n", + "\n", + " def forward(self, x):\n", + " x = self.backbone(x).flatten(start_dim=1)\n", + " x = self.projection_head(x)\n", + " x = nn.functional.normalize(x, dim=1, p=2)\n", + " p = self.prototypes(x)\n", + " return p\n", + "\n", + " def training_step(self, batch, batch_idx):\n", + " self.prototypes.normalize()\n", + " views = batch[0]\n", + " multi_crop_features = [self.forward(view.to(self.device)) for view in views]\n", + " high_resolution = multi_crop_features[:2]\n", + " low_resolution = multi_crop_features[2:]\n", + " loss = self.criterion(high_resolution, low_resolution)\n", + " return loss\n", + "\n", + " def configure_optimizers(self):\n", + " optim = torch.optim.Adam(self.parameters(), lr=0.001)\n", + " return optim" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "model = SwaV()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "transform = SwaVTransform()\n", + "# we ignore object detection annotations by setting target_transform to return 0\n", + "dataset = torchvision.datasets.VOCDetection(\n", + " \"datasets/pascal_voc\",\n", + " download=True,\n", + " transform=transform,\n", + " target_transform=lambda t: 0,\n", + ")\n", + "# or create a dataset from a folder containing images or videos:\n", + "# dataset = LightlyDataset(\"path/to/folder\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "dataloader 
= torch.utils.data.DataLoader(\n", + " dataset,\n", + " batch_size=128,\n", + " shuffle=True,\n", + " drop_last=True,\n", + " num_workers=8,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "accelerator = \"gpu\" if torch.cuda.is_available() else \"cpu\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [], + "source": [ + "trainer = pl.Trainer(max_epochs=10, devices=1, accelerator=accelerator)\n", + "trainer.fit(model=model, train_dataloaders=dataloader)" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": "-all" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/notebooks/pytorch_lightning/swav_queue.ipynb b/examples/notebooks/pytorch_lightning/swav_queue.ipynb new file mode 100644 index 000000000..744c8cc32 --- /dev/null +++ b/examples/notebooks/pytorch_lightning/swav_queue.ipynb @@ -0,0 +1,207 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "This example requires the following dependencies to be installed:\n", + "pip install lightly" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install lightly" + ] + }, + { + "cell_type": "markdown", + "id": "2", + "metadata": {}, + "source": [ + "Note: The model and training settings do not follow the reference settings\n", + "from the paper. The settings are chosen such that the example can easily be\n", + "run on a small dataset with a single GPU." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "import pytorch_lightning as pl\n", + "import torch\n", + "import torchvision\n", + "from torch import nn" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "from lightly.loss import SwaVLoss\n", + "from lightly.models.modules import SwaVProjectionHead, SwaVPrototypes\n", + "from lightly.models.modules.memory_bank import MemoryBankModule\n", + "from lightly.transforms.swav_transform import SwaVTransform" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "class SwaV(pl.LightningModule):\n", + " def __init__(self):\n", + " super().__init__()\n", + " resnet = torchvision.models.resnet18()\n", + " self.backbone = nn.Sequential(*list(resnet.children())[:-1])\n", + " self.projection_head = SwaVProjectionHead(512, 512, 128)\n", + " self.prototypes = SwaVPrototypes(128, 512, 1)\n", + " self.start_queue_at_epoch = 2\n", + " self.queues = nn.ModuleList(\n", + " [MemoryBankModule(size=(3840, 128)) for _ in range(2)]\n", + " )\n", + " self.criterion = SwaVLoss()\n", + "\n", + " def training_step(self, batch, batch_idx):\n", + " views = batch[0]\n", + " high_resolution, low_resolution = views[:2], views[2:]\n", + " self.prototypes.normalize()\n", + "\n", + " high_resolution_features = [self._subforward(x) for x in high_resolution]\n", + " low_resolution_features = [self._subforward(x) for x in low_resolution]\n", + "\n", + " high_resolution_prototypes = [\n", + " self.prototypes(x, self.current_epoch) for x in high_resolution_features\n", + " ]\n", + " low_resolution_prototypes = [\n", + " self.prototypes(x, self.current_epoch) for x in 
low_resolution_features\n", + " ]\n", + " queue_prototypes = self._get_queue_prototypes(high_resolution_features)\n", + " loss = self.criterion(\n", + " high_resolution_prototypes, low_resolution_prototypes, queue_prototypes\n", + " )\n", + " return loss\n", + "\n", + " def configure_optimizers(self):\n", + " optim = torch.optim.Adam(self.parameters(), lr=0.001)\n", + " return optim\n", + "\n", + " def _subforward(self, input):\n", + " features = self.backbone(input).flatten(start_dim=1)\n", + " features = self.projection_head(features)\n", + " features = nn.functional.normalize(features, dim=1, p=2)\n", + " return features\n", + "\n", + " @torch.no_grad()\n", + " def _get_queue_prototypes(self, high_resolution_features):\n", + " if len(high_resolution_features) != len(self.queues):\n", + " raise ValueError(\n", + " f\"The number of queues ({len(self.queues)}) should be equal to the number of high \"\n", + " f\"resolution inputs ({len(high_resolution_features)}). Set `n_queues` accordingly.\"\n", + " )\n", + "\n", + " # Get the queue features\n", + " queue_features = []\n", + " for i in range(len(self.queues)):\n", + " _, features = self.queues[i](high_resolution_features[i], update=True)\n", + " # Queue features are in (num_ftrs X queue_length) shape, while the high res\n", + " # features are in (batch_size X num_ftrs). Swap the axes for interoperability.\n", + " features = torch.permute(features, (1, 0))\n", + " queue_features.append(features)\n", + "\n", + " # If loss calculation with queue prototypes starts at a later epoch,\n", + " # just queue the features and return None instead of queue prototypes.\n", + " if (\n", + " self.start_queue_at_epoch > 0\n", + " and self.current_epoch < self.start_queue_at_epoch\n", + " ):\n", + " return None\n", + "\n", + " # Assign prototypes\n", + " queue_prototypes = [\n", + " self.prototypes(x, self.current_epoch) for x in queue_features\n", + " ]\n", + " return queue_prototypes" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "model = SwaV()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "transform = SwaVTransform()\n", + "# we ignore object detection annotations by setting target_transform to return 0\n", + "dataset = torchvision.datasets.VOCDetection(\n", + " \"datasets/pascal_voc\",\n", + " download=True,\n", + " transform=transform,\n", + " target_transform=lambda t: 0,\n", + ")\n", + "# or create a dataset from a folder containing images or videos:\n", + "# dataset = LightlyDataset(\"path/to/folder\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "dataloader = torch.utils.data.DataLoader(\n", + " dataset,\n", + " batch_size=128,\n", + " shuffle=True,\n", + " drop_last=True,\n", + " num_workers=8,\n", + ")\n", + "accelerator = \"gpu\" if torch.cuda.is_available() else \"cpu\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "trainer = pl.Trainer(max_epochs=10, devices=1, accelerator=accelerator)\n", + "trainer.fit(model=model, train_dataloaders=dataloader)" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": "-all" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/notebooks/pytorch_lightning/tico.ipynb 
b/examples/notebooks/pytorch_lightning/tico.ipynb new file mode 100644 index 000000000..9d39b1553 --- /dev/null +++ b/examples/notebooks/pytorch_lightning/tico.ipynb @@ -0,0 +1,189 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "This example requires the following dependencies to be installed:\n", + "pip install lightly" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install lightly" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2", + "metadata": {}, + "outputs": [], + "source": [ + "import copy" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "import pytorch_lightning as pl\n", + "import torch\n", + "import torchvision\n", + "from torch import nn" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "from lightly.loss.tico_loss import TiCoLoss\n", + "from lightly.models.modules.heads import TiCoProjectionHead\n", + "from lightly.models.utils import deactivate_requires_grad, update_momentum\n", + "from lightly.transforms.byol_transform import (\n", + " BYOLTransform,\n", + " BYOLView1Transform,\n", + " BYOLView2Transform,\n", + ")\n", + "from lightly.utils.scheduler import cosine_schedule" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "class TiCo(pl.LightningModule):\n", + " def __init__(self):\n", + " super().__init__()\n", + " resnet = torchvision.models.resnet18()\n", + " self.backbone = nn.Sequential(*list(resnet.children())[:-1])\n", + " self.projection_head = TiCoProjectionHead(512, 1024, 256)\n", + "\n", + " self.backbone_momentum = copy.deepcopy(self.backbone)\n", + " self.projection_head_momentum = copy.deepcopy(self.projection_head)\n", + "\n", + " deactivate_requires_grad(self.backbone_momentum)\n", + " deactivate_requires_grad(self.projection_head_momentum)\n", + "\n", + " self.criterion = TiCoLoss()\n", + "\n", + " def forward(self, x):\n", + " y = self.backbone(x).flatten(start_dim=1)\n", + " z = self.projection_head(y)\n", + " return z\n", + "\n", + " def forward_momentum(self, x):\n", + " y = self.backbone_momentum(x).flatten(start_dim=1)\n", + " z = self.projection_head_momentum(y)\n", + " z = z.detach()\n", + " return z\n", + "\n", + " def training_step(self, batch, batch_idx):\n", + " (x0, x1) = batch[0]\n", + " momentum = cosine_schedule(self.current_epoch, 10, 0.996, 1)\n", + " update_momentum(self.backbone, self.backbone_momentum, m=momentum)\n", + " update_momentum(self.projection_head, self.projection_head_momentum, m=momentum)\n", + " x0 = x0.to(self.device)\n", + " x1 = x1.to(self.device)\n", + " z0 = self.forward(x0)\n", + " z1 = self.forward_momentum(x1)\n", + " loss = self.criterion(z0, z1)\n", + " return loss\n", + "\n", + " def configure_optimizers(self):\n", + " return torch.optim.SGD(self.parameters(), lr=0.06)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "model = TiCo()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "# TiCo uses BYOL augmentations.\n", + "# We disable resizing and gaussian blur for cifar10.\n", + "transform = BYOLTransform(\n", + " view_1_transform=BYOLView1Transform(input_size=32, 
gaussian_blur=0.0),\n", + " view_2_transform=BYOLView2Transform(input_size=32, gaussian_blur=0.0),\n", + ")\n", + "dataset = torchvision.datasets.CIFAR10(\n", + " \"datasets/cifar10\", download=True, transform=transform\n", + ")\n", + "# or create a dataset from a folder containing images or videos:\n", + "# dataset = LightlyDataset(\"path/to/folder\", transform=transform)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "dataloader = torch.utils.data.DataLoader(\n", + " dataset,\n", + " batch_size=256,\n", + " shuffle=True,\n", + " drop_last=True,\n", + " num_workers=8,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "accelerator = \"gpu\" if torch.cuda.is_available() else \"cpu\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [], + "source": [ + "trainer = pl.Trainer(max_epochs=10, devices=1, accelerator=accelerator)\n", + "trainer.fit(model=model, train_dataloaders=dataloader)" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": "-all" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/notebooks/pytorch_lightning/vicreg.ipynb b/examples/notebooks/pytorch_lightning/vicreg.ipynb new file mode 100644 index 000000000..63f8d1613 --- /dev/null +++ b/examples/notebooks/pytorch_lightning/vicreg.ipynb @@ -0,0 +1,176 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "This example requires the following dependencies to be installed:\n", + "pip install lightly" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install lightly" + ] + }, + { + "cell_type": "markdown", + "id": "2", + "metadata": {}, + "source": [ + "Note: The model and training settings do not follow the reference settings\n", + "from the paper. The settings are chosen such that the example can easily be\n", + "run on a small dataset with a single GPU." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "import pytorch_lightning as pl\n", + "import torch\n", + "import torchvision\n", + "from torch import nn" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "from lightly.loss.vicreg_loss import VICRegLoss" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "## The projection head is the same as the Barlow Twins one\n", + "from lightly.models.modules.heads import VICRegProjectionHead\n", + "from lightly.transforms.vicreg_transform import VICRegTransform" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "class VICReg(pl.LightningModule):\n", + " def __init__(self):\n", + " super().__init__()\n", + " resnet = torchvision.models.resnet18()\n", + " self.backbone = nn.Sequential(*list(resnet.children())[:-1])\n", + " self.projection_head = VICRegProjectionHead(\n", + " input_dim=512,\n", + " hidden_dim=2048,\n", + " output_dim=2048,\n", + " num_layers=2,\n", + " )\n", + " self.criterion = VICRegLoss()\n", + "\n", + " def forward(self, x):\n", + " x = self.backbone(x).flatten(start_dim=1)\n", + " z = self.projection_head(x)\n", + " return z\n", + "\n", + " def training_step(self, batch, batch_index):\n", + " (x0, x1) = batch[0]\n", + " z0 = self.forward(x0)\n", + " z1 = self.forward(x1)\n", + " loss = self.criterion(z0, z1)\n", + " return loss\n", + "\n", + " def configure_optimizers(self):\n", + " optim = torch.optim.SGD(self.parameters(), lr=0.06)\n", + " return optim" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "model = VICReg()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "transform = VICRegTransform(input_size=32)\n", + "dataset = torchvision.datasets.CIFAR10(\n", + " \"datasets/cifar10\", download=True, transform=transform\n", + ")\n", + "# or create a dataset from a folder containing images or videos:\n", + "# dataset = LightlyDataset(\"path/to/folder\", transform=transform)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "dataloader = torch.utils.data.DataLoader(\n", + " dataset,\n", + " batch_size=256,\n", + " shuffle=True,\n", + " drop_last=True,\n", + " num_workers=8,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [], + "source": [ + "accelerator = \"gpu\" if torch.cuda.is_available() else \"cpu\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "11", + "metadata": {}, + "outputs": [], + "source": [ + "trainer = pl.Trainer(max_epochs=10, devices=1, accelerator=accelerator)\n", + "trainer.fit(model=model, train_dataloaders=dataloader)" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": "-all" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/notebooks/pytorch_lightning/vicregl.ipynb b/examples/notebooks/pytorch_lightning/vicregl.ipynb new file mode 100644 index 000000000..7c8c0b3ff --- /dev/null +++ b/examples/notebooks/pytorch_lightning/vicregl.ipynb @@ -0,0 +1,187 @@ +{ + "cells": [ + { + 
"cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "This example requires the following dependencies to be installed:\n", + "pip install lightly" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install lightly" + ] + }, + { + "cell_type": "markdown", + "id": "2", + "metadata": {}, + "source": [ + "Note: The model and training settings do not follow the reference settings\n", + "from the paper. The settings are chosen such that the example can easily be\n", + "run on a small dataset with a single GPU." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "import pytorch_lightning as pl\n", + "import torch\n", + "import torchvision\n", + "from torch import nn" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "from lightly.loss import VICRegLLoss" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "## The global projection head is the same as the Barlow Twins one\n", + "from lightly.models.modules import BarlowTwinsProjectionHead\n", + "from lightly.models.modules.heads import VicRegLLocalProjectionHead\n", + "from lightly.transforms.vicregl_transform import VICRegLTransform" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "class VICRegL(pl.LightningModule):\n", + " def __init__(self):\n", + " super().__init__()\n", + " resnet = torchvision.models.resnet18()\n", + " self.backbone = nn.Sequential(*list(resnet.children())[:-2])\n", + " self.projection_head = BarlowTwinsProjectionHead(512, 2048, 2048)\n", + " self.local_projection_head = VicRegLLocalProjectionHead(512, 128, 128)\n", + " self.average_pool = nn.AdaptiveAvgPool2d(output_size=(1, 1))\n", + " self.criterion = VICRegLLoss()\n", + "\n", + " def forward(self, x):\n", + " x = self.backbone(x)\n", + " y = self.average_pool(x).flatten(start_dim=1)\n", + " z = self.projection_head(y)\n", + " y_local = x.permute(0, 2, 3, 1) # (B, D, W, H) to (B, W, H, D)\n", + " z_local = self.local_projection_head(y_local)\n", + " return z, z_local\n", + "\n", + " def training_step(self, batch, batch_index):\n", + " views_and_grids = batch[0]\n", + " views = views_and_grids[: len(views_and_grids) // 2]\n", + " grids = views_and_grids[len(views_and_grids) // 2 :]\n", + " features = [self.forward(view) for view in views]\n", + " loss = self.criterion(\n", + " global_view_features=features[:2],\n", + " global_view_grids=grids[:2],\n", + " local_view_features=features[2:],\n", + " local_view_grids=grids[2:],\n", + " )\n", + " return loss\n", + "\n", + " def configure_optimizers(self):\n", + " optim = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)\n", + " return optim" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "model = VICRegL()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "transform = VICRegLTransform(n_local_views=0)\n", + "# we ignore object detection annotations by setting target_transform to return 0\n", + "dataset = torchvision.datasets.VOCDetection(\n", + " \"datasets/pascal_voc\",\n", + " download=True,\n", + " transform=transform,\n", + " target_transform=lambda t: 0,\n", 
+ ")\n", + "# or create a dataset from a folder containing images or videos:\n", + "# dataset = LightlyDataset(\"path/to/folder\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "dataloader = torch.utils.data.DataLoader(\n", + " dataset,\n", + " batch_size=256,\n", + " shuffle=True,\n", + " drop_last=True,\n", + " num_workers=8,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [], + "source": [ + "accelerator = \"gpu\" if torch.cuda.is_available() else \"cpu\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "11", + "metadata": {}, + "outputs": [], + "source": [ + "trainer = pl.Trainer(max_epochs=10, devices=1, accelerator=accelerator)\n", + "trainer.fit(model=model, train_dataloaders=dataloader)" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": "-all" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/notebooks/pytorch_lightning_distributed/aim.ipynb b/examples/notebooks/pytorch_lightning_distributed/aim.ipynb new file mode 100644 index 000000000..f119dbe0e --- /dev/null +++ b/examples/notebooks/pytorch_lightning_distributed/aim.ipynb @@ -0,0 +1,201 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "This example requires the following dependencies to be installed:\n", + "pip install \"lightly[timm]\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install \"lightly[timm]\"" + ] + }, + { + "cell_type": "markdown", + "id": "2", + "metadata": {}, + "source": [ + "Note: The model and training settings do not follow the reference settings\n", + "from the paper. The settings are chosen such that the example can easily be\n", + "run on a small dataset with a single GPU." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "import pytorch_lightning as pl\n", + "import torch\n", + "import torchvision\n", + "from torch import nn" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "from lightly.models import utils\n", + "from lightly.models.modules import AIMPredictionHead, MaskedCausalVisionTransformer\n", + "from lightly.transforms import AIMTransform" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "class AIM(pl.LightningModule):\n", + " def __init__(self) -> None:\n", + " super().__init__()\n", + " vit = MaskedCausalVisionTransformer(\n", + " img_size=224,\n", + " patch_size=32,\n", + " embed_dim=768,\n", + " depth=12,\n", + " num_heads=12,\n", + " qk_norm=False,\n", + " class_token=False,\n", + " no_embed_class=True,\n", + " )\n", + " utils.initialize_2d_sine_cosine_positional_embedding(\n", + " pos_embedding=vit.pos_embed, has_class_token=vit.has_class_token\n", + " )\n", + " self.patch_size = vit.patch_embed.patch_size[0]\n", + " self.num_patches = vit.patch_embed.num_patches\n", + "\n", + " self.backbone = vit\n", + " self.projection_head = AIMPredictionHead(\n", + " input_dim=vit.embed_dim, output_dim=3 * self.patch_size**2, num_blocks=1\n", + " )\n", + "\n", + " self.criterion = nn.MSELoss()\n", + "\n", + " def training_step(self, batch, batch_idx):\n", + " views, targets = batch[0], batch[1]\n", + " images = views[0] # AIM has only a single view\n", + " batch_size = images.shape[0]\n", + "\n", + " mask = utils.random_prefix_mask(\n", + " size=(batch_size, self.num_patches),\n", + " max_prefix_length=self.num_patches - 1,\n", + " device=images.device,\n", + " )\n", + " features = self.backbone.forward_features(images, mask=mask)\n", + " # Add positional embedding before head.\n", + " features = self.backbone._pos_embed(features)\n", + " predictions = self.projection_head(features)\n", + "\n", + " # Convert images to patches and normalize them.\n", + " patches = utils.patchify(images, self.patch_size)\n", + " patches = utils.normalize_mean_var(patches, dim=-1)\n", + "\n", + " loss = self.criterion(predictions, patches)\n", + " return loss\n", + "\n", + " def configure_optimizers(self):\n", + " optim = torch.optim.AdamW(self.parameters(), lr=1.5e-4)\n", + " return optim" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "model = AIM()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "transform = AIMTransform()\n", + "# we ignore object detection annotations by setting target_transform to return 0\n", + "dataset = torchvision.datasets.VOCDetection(\n", + " \"datasets/pascal_voc\",\n", + " download=True,\n", + " transform=transform,\n", + " target_transform=lambda t: 0,\n", + ")\n", + "# or create a dataset from a folder containing images or videos:\n", + "# dataset = LightlyDataset(\"path/to/folder\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "dataloader = torch.utils.data.DataLoader(\n", + " dataset,\n", + " batch_size=256,\n", + " shuffle=True,\n", + " drop_last=True,\n", + " num_workers=8,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": 
{}, + "outputs": [], + "source": [ + "accelerator = \"gpu\" if torch.cuda.is_available() else \"cpu\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [], + "source": [ + "# Train with DDP on multiple gpus. Distributed sampling is also enabled with\n", + "# replace_sampler_ddp=True.\n", + "trainer = pl.Trainer(\n", + " max_epochs=10,\n", + " devices=\"auto\",\n", + " accelerator=\"gpu\",\n", + " strategy=\"ddp\",\n", + " use_distributed_sampler=True, # or replace_sampler_ddp=True for PyTorch Lightning <2.0\n", + ")\n", + "trainer.fit(model=model, train_dataloaders=dataloader)" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": "-all" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/notebooks/pytorch_lightning_distributed/barlowtwins.ipynb b/examples/notebooks/pytorch_lightning_distributed/barlowtwins.ipynb new file mode 100644 index 000000000..ae8cb7889 --- /dev/null +++ b/examples/notebooks/pytorch_lightning_distributed/barlowtwins.ipynb @@ -0,0 +1,172 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "This example requires the following dependencies to be installed:\n", + "pip install lightly" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install lightly" + ] + }, + { + "cell_type": "markdown", + "id": "2", + "metadata": {}, + "source": [ + "Note: The model and training settings do not follow the reference settings\n", + "from the paper. The settings are chosen such that the example can easily be\n", + "run on a small dataset with a single GPU." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "import pytorch_lightning as pl\n", + "import torch\n", + "import torchvision\n", + "from torch import nn" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "from lightly.loss import BarlowTwinsLoss\n", + "from lightly.models.modules import BarlowTwinsProjectionHead\n", + "from lightly.transforms.byol_transform import (\n", + " BYOLTransform,\n", + " BYOLView1Transform,\n", + " BYOLView2Transform,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "class BarlowTwins(pl.LightningModule):\n", + " def __init__(self):\n", + " super().__init__()\n", + " resnet = torchvision.models.resnet18()\n", + " self.backbone = nn.Sequential(*list(resnet.children())[:-1])\n", + " self.projection_head = BarlowTwinsProjectionHead(512, 2048, 2048)\n", + "\n", + " # enable gather_distributed to gather features from all gpus\n", + " # before calculating the loss\n", + " self.criterion = BarlowTwinsLoss(gather_distributed=True)\n", + "\n", + " def forward(self, x):\n", + " x = self.backbone(x).flatten(start_dim=1)\n", + " z = self.projection_head(x)\n", + " return z\n", + "\n", + " def training_step(self, batch, batch_index):\n", + " (x0, x1) = batch[0]\n", + " z0 = self.forward(x0)\n", + " z1 = self.forward(x1)\n", + " loss = self.criterion(z0, z1)\n", + " return loss\n", + "\n", + " def configure_optimizers(self):\n", + " optim = torch.optim.SGD(self.parameters(), lr=0.06)\n", + " return optim" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": 
{}, + "outputs": [], + "source": [ + "model = BarlowTwins()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "# BarlowTwins uses BYOL augmentations.\n", + "# We disable resizing and gaussian blur for cifar10.\n", + "transform = BYOLTransform(\n", + " view_1_transform=BYOLView1Transform(input_size=32, gaussian_blur=0.0),\n", + " view_2_transform=BYOLView2Transform(input_size=32, gaussian_blur=0.0),\n", + ")\n", + "dataset = torchvision.datasets.CIFAR10(\n", + " \"datasets/cifar10\", download=True, transform=transform\n", + ")\n", + "# or create a dataset from a folder containing images or videos:\n", + "# dataset = LightlyDataset(\"path/to/folder\", transform=transform)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "dataloader = torch.utils.data.DataLoader(\n", + " dataset,\n", + " batch_size=256,\n", + " shuffle=True,\n", + " drop_last=True,\n", + " num_workers=8,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "# Train with DDP and use Synchronized Batch Norm for a more accurate batch norm\n", + "# calculation. Distributed sampling is also enabled with replace_sampler_ddp=True.\n", + "trainer = pl.Trainer(\n", + " max_epochs=10,\n", + " devices=\"auto\",\n", + " accelerator=\"gpu\",\n", + " strategy=\"ddp\",\n", + " sync_batchnorm=True,\n", + " use_distributed_sampler=True, # or replace_sampler_ddp=True for PyTorch Lightning <2.0\n", + ")\n", + "trainer.fit(model=model, train_dataloaders=dataloader)" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": "-all" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/notebooks/pytorch_lightning_distributed/byol.ipynb b/examples/notebooks/pytorch_lightning_distributed/byol.ipynb new file mode 100644 index 000000000..a9ddcf704 --- /dev/null +++ b/examples/notebooks/pytorch_lightning_distributed/byol.ipynb @@ -0,0 +1,200 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "This example requires the following dependencies to be installed:\n", + "pip install lightly" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install lightly" + ] + }, + { + "cell_type": "markdown", + "id": "2", + "metadata": {}, + "source": [ + "Note: The model and training settings do not follow the reference settings\n", + "from the paper. The settings are chosen such that the example can easily be\n", + "run on a small dataset with a single GPU." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "import copy" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "import pytorch_lightning as pl\n", + "import torch\n", + "import torchvision\n", + "from torch import nn" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "from lightly.loss import NegativeCosineSimilarity\n", + "from lightly.models.modules import BYOLProjectionHead\n", + "from lightly.models.modules.heads import BYOLPredictionHead\n", + "from lightly.models.utils import deactivate_requires_grad, update_momentum\n", + "from lightly.transforms.byol_transform import (\n", + " BYOLTransform,\n", + " BYOLView1Transform,\n", + " BYOLView2Transform,\n", + ")\n", + "from lightly.utils.scheduler import cosine_schedule" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "class BYOL(pl.LightningModule):\n", + " def __init__(self):\n", + " super().__init__()\n", + " resnet = torchvision.models.resnet18()\n", + " self.backbone = nn.Sequential(*list(resnet.children())[:-1])\n", + " self.projection_head = BYOLProjectionHead(512, 1024, 256)\n", + " self.prediction_head = BYOLPredictionHead(256, 1024, 256)\n", + "\n", + " self.backbone_momentum = copy.deepcopy(self.backbone)\n", + " self.projection_head_momentum = copy.deepcopy(self.projection_head)\n", + "\n", + " deactivate_requires_grad(self.backbone_momentum)\n", + " deactivate_requires_grad(self.projection_head_momentum)\n", + "\n", + " self.criterion = NegativeCosineSimilarity()\n", + "\n", + " def forward(self, x):\n", + " y = self.backbone(x).flatten(start_dim=1)\n", + " z = self.projection_head(y)\n", + " p = self.prediction_head(z)\n", + " return p\n", + "\n", + " def forward_momentum(self, x):\n", + " y = self.backbone_momentum(x).flatten(start_dim=1)\n", + " z = self.projection_head_momentum(y)\n", + " z = z.detach()\n", + " return z\n", + "\n", + " def training_step(self, batch, batch_idx):\n", + " momentum = cosine_schedule(self.current_epoch, 10, 0.996, 1)\n", + " update_momentum(self.backbone, self.backbone_momentum, m=momentum)\n", + " update_momentum(self.projection_head, self.projection_head_momentum, m=momentum)\n", + " (x0, x1) = batch[0]\n", + " p0 = self.forward(x0)\n", + " z0 = self.forward_momentum(x0)\n", + " p1 = self.forward(x1)\n", + " z1 = self.forward_momentum(x1)\n", + " loss = 0.5 * (self.criterion(p0, z1) + self.criterion(p1, z0))\n", + " return loss\n", + "\n", + " def configure_optimizers(self):\n", + " return torch.optim.SGD(self.parameters(), lr=0.06)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "model = BYOL()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "# We disable resizing and gaussian blur for cifar10.\n", + "transform = BYOLTransform(\n", + " view_1_transform=BYOLView1Transform(input_size=32, gaussian_blur=0.0),\n", + " view_2_transform=BYOLView2Transform(input_size=32, gaussian_blur=0.0),\n", + ")\n", + "dataset = torchvision.datasets.CIFAR10(\n", + " \"datasets/cifar10\", download=True, transform=transform\n", + ")\n", + "# or create a dataset from a folder containing images or videos:\n", + "# dataset = 
LightlyDataset(\"path/to/folder\", transform=transform)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "dataloader = torch.utils.data.DataLoader(\n", + " dataset,\n", + " batch_size=256,\n", + " shuffle=True,\n", + " drop_last=True,\n", + " num_workers=8,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [], + "source": [ + "# Train with DDP and use Synchronized Batch Norm for a more accurate batch norm\n", + "# calculation. Distributed sampling is also enabled with replace_sampler_ddp=True.\n", + "trainer = pl.Trainer(\n", + " max_epochs=10,\n", + " devices=\"auto\",\n", + " accelerator=\"gpu\",\n", + " strategy=\"ddp\",\n", + " sync_batchnorm=True,\n", + " use_distributed_sampler=True, # or replace_sampler_ddp=True for PyTorch Lightning <2.0\n", + ")\n", + "trainer.fit(model=model, train_dataloaders=dataloader)" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": "-all" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/notebooks/pytorch_lightning_distributed/dcl.ipynb b/examples/notebooks/pytorch_lightning_distributed/dcl.ipynb new file mode 100644 index 000000000..4f8897e1a --- /dev/null +++ b/examples/notebooks/pytorch_lightning_distributed/dcl.ipynb @@ -0,0 +1,165 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "This example requires the following dependencies to be installed:\n", + "pip install lightly" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install lightly" + ] + }, + { + "cell_type": "markdown", + "id": "2", + "metadata": {}, + "source": [ + "Note: The model and training settings do not follow the reference settings\n", + "from the paper. The settings are chosen such that the example can easily be\n", + "run on a small dataset with a single GPU." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "import pytorch_lightning as pl\n", + "import torch\n", + "import torchvision\n", + "from torch import nn" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "from lightly.loss import DCLLoss\n", + "from lightly.models.modules import SimCLRProjectionHead\n", + "from lightly.transforms.simclr_transform import SimCLRTransform" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "class DCL(pl.LightningModule):\n", + " def __init__(self):\n", + " super().__init__()\n", + " resnet = torchvision.models.resnet18()\n", + " self.backbone = nn.Sequential(*list(resnet.children())[:-1])\n", + " self.projection_head = SimCLRProjectionHead(512, 2048, 2048)\n", + "\n", + " # enable gather_distributed to gather features from all gpus\n", + " # before calculating the loss\n", + " self.criterion = DCLLoss(gather_distributed=True)\n", + " # or use the weighted DCLW loss:\n", + " # self.criterion = DCLWLoss(gather_distributed=True)\n", + "\n", + " def forward(self, x):\n", + " x = self.backbone(x).flatten(start_dim=1)\n", + " z = self.projection_head(x)\n", + " return z\n", + "\n", + " def training_step(self, batch, batch_index):\n", + " (x0, x1) = batch[0]\n", + " z0 = self.forward(x0)\n", + " z1 = self.forward(x1)\n", + " loss = self.criterion(z0, z1)\n", + " return loss\n", + "\n", + " def configure_optimizers(self):\n", + " optim = torch.optim.SGD(self.parameters(), lr=0.06)\n", + " return optim" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "model = DCL()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "transform = SimCLRTransform(input_size=32)\n", + "dataset = torchvision.datasets.CIFAR10(\n", + " \"datasets/cifar10\", download=True, transform=transform\n", + ")\n", + "# or create a dataset from a folder containing images or videos:\n", + "# dataset = LightlyDataset(\"path/to/folder\", transform=transform)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "dataloader = torch.utils.data.DataLoader(\n", + " dataset,\n", + " batch_size=256,\n", + " shuffle=True,\n", + " drop_last=True,\n", + " num_workers=8,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "# Train with DDP and use Synchronized Batch Norm for a more accurate batch norm\n", + "# calculation. 
Distributed sampling is also enabled with replace_sampler_ddp=True.\n", + "trainer = pl.Trainer(\n", + " max_epochs=10,\n", + " devices=\"auto\",\n", + " accelerator=\"gpu\",\n", + " strategy=\"ddp\",\n", + " sync_batchnorm=True,\n", + " use_distributed_sampler=True, # or replace_sampler_ddp=True for PyTorch Lightning <2.0\n", + ")\n", + "trainer.fit(model=model, train_dataloaders=dataloader)" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": "-all" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/notebooks/pytorch_lightning_distributed/densecl.ipynb b/examples/notebooks/pytorch_lightning_distributed/densecl.ipynb new file mode 100644 index 000000000..4d46be72b --- /dev/null +++ b/examples/notebooks/pytorch_lightning_distributed/densecl.ipynb @@ -0,0 +1,220 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "This example requires the following dependencies to be installed:\n", + "pip install lightly" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install lightly" + ] + }, + { + "cell_type": "markdown", + "id": "2", + "metadata": {}, + "source": [ + "Note: The model and training settings do not follow the reference settings\n", + "from the paper. The settings are chosen such that the example can easily be\n", + "run on a small dataset with a single GPU." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "import copy" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "import pytorch_lightning as pl\n", + "import torch\n", + "import torchvision\n", + "from torch import nn" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "from lightly.loss import NTXentLoss\n", + "from lightly.models import utils\n", + "from lightly.models.modules import DenseCLProjectionHead\n", + "from lightly.transforms import DenseCLTransform\n", + "from lightly.utils.scheduler import cosine_schedule" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "class DenseCL(pl.LightningModule):\n", + " def __init__(self):\n", + " super().__init__()\n", + " resnet = torchvision.models.resnet18()\n", + " self.backbone = nn.Sequential(*list(resnet.children())[:-2])\n", + " self.projection_head_global = DenseCLProjectionHead(512, 512, 128)\n", + " self.projection_head_local = DenseCLProjectionHead(512, 512, 128)\n", + "\n", + " self.backbone_momentum = copy.deepcopy(self.backbone)\n", + " self.projection_head_global_momentum = copy.deepcopy(\n", + " self.projection_head_global\n", + " )\n", + " self.projection_head_local_momentum = copy.deepcopy(self.projection_head_local)\n", + " self.pool = nn.AdaptiveAvgPool2d((1, 1))\n", + "\n", + " utils.deactivate_requires_grad(self.backbone_momentum)\n", + " utils.deactivate_requires_grad(self.projection_head_global_momentum)\n", + " utils.deactivate_requires_grad(self.projection_head_local_momentum)\n", + "\n", + " self.criterion_global = NTXentLoss(memory_bank_size=(4096, 128))\n", + " self.criterion_local = NTXentLoss(memory_bank_size=(4096, 128))\n", + "\n", + " def forward(self, x):\n", + " query_features = self.backbone(x)\n", + " query_global = 
self.pool(query_features).flatten(start_dim=1)\n", + " query_global = self.projection_head_global(query_global)\n", + " query_features = query_features.flatten(start_dim=2).permute(0, 2, 1)\n", + " query_local = self.projection_head_local(query_features)\n", + " # Shapes: (B, H*W, C), (B, D), (B, H*W, D)\n", + " return query_features, query_global, query_local\n", + "\n", + " @torch.no_grad()\n", + " def forward_momentum(self, x):\n", + " key_features = self.backbone(x)\n", + " key_global = self.pool(key_features).flatten(start_dim=1)\n", + " key_global = self.projection_head_global(key_global)\n", + " key_features = key_features.flatten(start_dim=2).permute(0, 2, 1)\n", + " key_local = self.projection_head_local(key_features)\n", + " return key_features, key_global, key_local\n", + "\n", + " def training_step(self, batch, batch_idx):\n", + " momentum = cosine_schedule(self.current_epoch, 10, 0.996, 1)\n", + " utils.update_momentum(model.backbone, model.backbone_momentum, m=momentum)\n", + " utils.update_momentum(\n", + " model.projection_head_global,\n", + " model.projection_head_global_momentum,\n", + " m=momentum,\n", + " )\n", + " utils.update_momentum(\n", + " model.projection_head_local,\n", + " model.projection_head_local_momentum,\n", + " m=momentum,\n", + " )\n", + "\n", + " x_query, x_key = batch[0]\n", + " query_features, query_global, query_local = self(x_query)\n", + " key_features, key_global, key_local = self.forward_momentum(x_key)\n", + "\n", + " key_local = utils.select_most_similar(query_features, key_features, key_local)\n", + " query_local = query_local.flatten(end_dim=1)\n", + " key_local = key_local.flatten(end_dim=1)\n", + "\n", + " loss_global = self.criterion_global(query_global, key_global)\n", + " loss_local = self.criterion_local(query_local, key_local)\n", + " lambda_ = 0.5\n", + " loss = (1 - lambda_) * loss_global + lambda_ * loss_local\n", + " return loss\n", + "\n", + " def configure_optimizers(self):\n", + " optim = torch.optim.SGD(self.parameters(), lr=0.06)\n", + " return optim" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "model = DenseCL()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "transform = DenseCLTransform(input_size=32)\n", + "dataset = torchvision.datasets.CIFAR10(\n", + " \"datasets/cifar10\", download=True, transform=transform\n", + ")\n", + "# or create a dataset from a folder containing images or videos:\n", + "# dataset = LightlyDataset(\"path/to/folder\", transform=transform)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "dataloader = torch.utils.data.DataLoader(\n", + " dataset,\n", + " batch_size=256,\n", + " shuffle=True,\n", + " drop_last=True,\n", + " num_workers=8,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [], + "source": [ + "# Train with DDP and use Synchronized Batch Norm for a more accurate batch norm\n", + "# calculation. 
Distributed sampling is also enabled with replace_sampler_ddp=True.\n", + "trainer = pl.Trainer(\n", + " max_epochs=10,\n", + " devices=\"auto\",\n", + " accelerator=\"gpu\",\n", + " strategy=\"ddp\",\n", + " sync_batchnorm=True,\n", + " use_distributed_sampler=True, # or replace_sampler_ddp=True for PyTorch Lightning <2.0\n", + ")\n", + "trainer.fit(model=model, train_dataloaders=dataloader)" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": "-all" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/notebooks/pytorch_lightning_distributed/dino.ipynb b/examples/notebooks/pytorch_lightning_distributed/dino.ipynb new file mode 100644 index 000000000..9b5681eda --- /dev/null +++ b/examples/notebooks/pytorch_lightning_distributed/dino.ipynb @@ -0,0 +1,203 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "This example requires the following dependencies to be installed:\n", + "pip install lightly" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install lightly" + ] + }, + { + "cell_type": "markdown", + "id": "2", + "metadata": {}, + "source": [ + "Note: The model and training settings do not follow the reference settings\n", + "from the paper. The settings are chosen such that the example can easily be\n", + "run on a small dataset with a single GPU." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "import copy" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "import pytorch_lightning as pl\n", + "import torch\n", + "import torchvision\n", + "from torch import nn" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "from lightly.loss import DINOLoss\n", + "from lightly.models.modules import DINOProjectionHead\n", + "from lightly.models.utils import deactivate_requires_grad, update_momentum\n", + "from lightly.transforms.dino_transform import DINOTransform\n", + "from lightly.utils.scheduler import cosine_schedule" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "class DINO(pl.LightningModule):\n", + " def __init__(self):\n", + " super().__init__()\n", + " resnet = torchvision.models.resnet18()\n", + " backbone = nn.Sequential(*list(resnet.children())[:-1])\n", + " input_dim = 512\n", + " # instead of a resnet you can also use a vision transformer backbone as in the\n", + " # original paper (you might have to reduce the batch size in this case):\n", + " # backbone = torch.hub.load('facebookresearch/dino:main', 'dino_vits16', pretrained=False)\n", + " # input_dim = backbone.embed_dim\n", + "\n", + " self.student_backbone = backbone\n", + " self.student_head = DINOProjectionHead(\n", + " input_dim, 512, 64, 2048, freeze_last_layer=1\n", + " )\n", + " self.teacher_backbone = copy.deepcopy(backbone)\n", + " self.teacher_head = DINOProjectionHead(input_dim, 512, 64, 2048)\n", + " deactivate_requires_grad(self.teacher_backbone)\n", + " deactivate_requires_grad(self.teacher_head)\n", + "\n", + " self.criterion = DINOLoss(output_dim=2048, warmup_teacher_temp_epochs=5)\n", + "\n", + " def forward(self, x):\n", + " y = 
self.student_backbone(x).flatten(start_dim=1)\n", + " z = self.student_head(y)\n", + " return z\n", + "\n", + " def forward_teacher(self, x):\n", + " y = self.teacher_backbone(x).flatten(start_dim=1)\n", + " z = self.teacher_head(y)\n", + " return z\n", + "\n", + " def training_step(self, batch, batch_idx):\n", + " momentum = cosine_schedule(self.current_epoch, 10, 0.996, 1)\n", + " update_momentum(self.student_backbone, self.teacher_backbone, m=momentum)\n", + " update_momentum(self.student_head, self.teacher_head, m=momentum)\n", + " views = batch[0]\n", + " views = [view.to(self.device) for view in views]\n", + " global_views = views[:2]\n", + " teacher_out = [self.forward_teacher(view) for view in global_views]\n", + " student_out = [self.forward(view) for view in views]\n", + " loss = self.criterion(teacher_out, student_out, epoch=self.current_epoch)\n", + " return loss\n", + "\n", + " def on_after_backward(self):\n", + " self.student_head.cancel_last_layer_gradients(current_epoch=self.current_epoch)\n", + "\n", + " def configure_optimizers(self):\n", + " optim = torch.optim.Adam(self.parameters(), lr=0.001)\n", + " return optim" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "model = DINO()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "transform = DINOTransform()\n", + "# we ignore object detection annotations by setting target_transform to return 0\n", + "dataset = torchvision.datasets.VOCDetection(\n", + " \"datasets/pascal_voc\",\n", + " download=True,\n", + " transform=transform,\n", + " target_transform=lambda t: 0,\n", + ")\n", + "# or create a dataset from a folder containing images or videos:\n", + "# dataset = LightlyDataset(\"path/to/folder\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "dataloader = torch.utils.data.DataLoader(\n", + " dataset,\n", + " batch_size=64,\n", + " shuffle=True,\n", + " drop_last=True,\n", + " num_workers=8,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [], + "source": [ + "# Train with DDP and use Synchronized Batch Norm for a more accurate batch norm\n", + "# calculation. 
Distributed sampling is also enabled with replace_sampler_ddp=True.\n", + "trainer = pl.Trainer(\n", + " max_epochs=10,\n", + " devices=\"auto\",\n", + " accelerator=\"gpu\",\n", + " strategy=\"ddp\",\n", + " sync_batchnorm=True,\n", + " use_distributed_sampler=True, # or replace_sampler_ddp=True for PyTorch Lightning <2.0\n", + ")\n", + "trainer.fit(model=model, train_dataloaders=dataloader)" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": "-all" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/notebooks/pytorch_lightning_distributed/fastsiam.ipynb b/examples/notebooks/pytorch_lightning_distributed/fastsiam.ipynb new file mode 100644 index 000000000..7db997bd5 --- /dev/null +++ b/examples/notebooks/pytorch_lightning_distributed/fastsiam.ipynb @@ -0,0 +1,170 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "This example requires the following dependencies to be installed:\n", + "pip install lightly" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install lightly" + ] + }, + { + "cell_type": "markdown", + "id": "2", + "metadata": {}, + "source": [ + "Note: The model and training settings do not follow the reference settings\n", + "from the paper. The settings are chosen such that the example can easily be\n", + "run on a small dataset with a single GPU." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "import pytorch_lightning as pl\n", + "import torch\n", + "import torchvision\n", + "from torch import nn" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "from lightly.loss import NegativeCosineSimilarity\n", + "from lightly.models.modules import SimSiamPredictionHead, SimSiamProjectionHead\n", + "from lightly.transforms import FastSiamTransform" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "class FastSiam(pl.LightningModule):\n", + " def __init__(self):\n", + " super().__init__()\n", + " resnet = torchvision.models.resnet18()\n", + " self.backbone = nn.Sequential(*list(resnet.children())[:-1])\n", + " self.projection_head = SimSiamProjectionHead(512, 512, 128)\n", + " self.prediction_head = SimSiamPredictionHead(128, 64, 128)\n", + " self.criterion = NegativeCosineSimilarity()\n", + "\n", + " def forward(self, x):\n", + " f = self.backbone(x).flatten(start_dim=1)\n", + " z = self.projection_head(f)\n", + " p = self.prediction_head(z)\n", + " z = z.detach()\n", + " return z, p\n", + "\n", + " def training_step(self, batch, batch_idx):\n", + " views = batch[0]\n", + " features = [self.forward(view) for view in views]\n", + " zs = torch.stack([z for z, _ in features])\n", + " ps = torch.stack([p for _, p in features])\n", + "\n", + " loss = 0.0\n", + " for i in range(len(views)):\n", + " mask = torch.arange(len(views), device=self.device) != i\n", + " loss += self.criterion(ps[i], torch.mean(zs[mask], dim=0)) / len(views)\n", + "\n", + " self.log(\"train_loss_ssl\", loss)\n", + " return loss\n", + "\n", + " def configure_optimizers(self):\n", + " optim = torch.optim.SGD(self.parameters(), lr=0.06)\n", + " return optim" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + 
"outputs": [], + "source": [ + "model = FastSiam()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "transform = FastSiamTransform(input_size=32)\n", + "dataset = torchvision.datasets.CIFAR10(\n", + " \"datasets/cifar10\", download=True, transform=transform\n", + ")\n", + "# or create a dataset from a folder containing images or videos:\n", + "# dataset = LightlyDataset(\"path/to/folder\", transform=transform)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "dataloader = torch.utils.data.DataLoader(\n", + " dataset,\n", + " batch_size=256,\n", + " shuffle=True,\n", + " drop_last=True,\n", + " num_workers=8,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "# Train with DDP and use Synchronized Batch Norm for a more accurate batch norm\n", + "# calculation. Distributed sampling is also enabled with replace_sampler_ddp=True.\n", + "trainer = pl.Trainer(\n", + " max_epochs=10,\n", + " devices=\"auto\",\n", + " accelerator=\"gpu\",\n", + " strategy=\"ddp\",\n", + " sync_batchnorm=True,\n", + " use_distributed_sampler=True, # or replace_sampler_ddp=True for PyTorch Lightning <2.0\n", + ")\n", + "trainer.fit(model=model, train_dataloaders=dataloader)" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": "-all" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/notebooks/pytorch_lightning_distributed/ijepa.ipynb b/examples/notebooks/pytorch_lightning_distributed/ijepa.ipynb new file mode 100644 index 000000000..613c4b28e --- /dev/null +++ b/examples/notebooks/pytorch_lightning_distributed/ijepa.ipynb @@ -0,0 +1,21 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "TODO, for now please refer to our pure pytorch example" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": "-all" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/notebooks/pytorch_lightning_distributed/mae.ipynb b/examples/notebooks/pytorch_lightning_distributed/mae.ipynb new file mode 100644 index 000000000..c8fd6cd7d --- /dev/null +++ b/examples/notebooks/pytorch_lightning_distributed/mae.ipynb @@ -0,0 +1,202 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "This example requires the following dependencies to be installed:\n", + "pip install \"lightly[timm]\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install \"lightly[timm]\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2", + "metadata": {}, + "outputs": [], + "source": [ + "# Note: The model and training settings do not follow the reference settings\n", + "# from the paper. 
The settings are chosen such that the example can easily be\n", + "# run on a small dataset with a single GPU.\n", + "import pytorch_lightning as pl\n", + "import torch\n", + "import torchvision\n", + "from timm.models.vision_transformer import vit_base_patch32_224\n", + "from torch import nn" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "from lightly.models import utils\n", + "from lightly.models.modules import MAEDecoderTIMM, MaskedVisionTransformerTIMM\n", + "from lightly.transforms import MAETransform" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "class MAE(pl.LightningModule):\n", + " def __init__(self):\n", + " super().__init__()\n", + "\n", + " decoder_dim = 512\n", + " vit = vit_base_patch32_224()\n", + " self.mask_ratio = 0.75\n", + " self.patch_size = vit.patch_embed.patch_size[0]\n", + " self.backbone = MaskedVisionTransformerTIMM(vit=vit)\n", + " self.sequence_length = self.backbone.sequence_length\n", + " self.decoder = MAEDecoderTIMM(\n", + " num_patches=vit.patch_embed.num_patches,\n", + " patch_size=self.patch_size,\n", + " embed_dim=vit.embed_dim,\n", + " decoder_embed_dim=decoder_dim,\n", + " decoder_depth=1,\n", + " decoder_num_heads=16,\n", + " mlp_ratio=4.0,\n", + " proj_drop_rate=0.0,\n", + " attn_drop_rate=0.0,\n", + " )\n", + " self.criterion = nn.MSELoss()\n", + "\n", + " def forward_encoder(self, images, idx_keep=None):\n", + " return self.backbone.encode(images=images, idx_keep=idx_keep)\n", + "\n", + " def forward_decoder(self, x_encoded, idx_keep, idx_mask):\n", + " # build decoder input\n", + " batch_size = x_encoded.shape[0]\n", + " x_decode = self.decoder.embed(x_encoded)\n", + " x_masked = utils.repeat_token(\n", + " self.decoder.mask_token, (batch_size, self.sequence_length)\n", + " )\n", + " x_masked = utils.set_at_index(x_masked, idx_keep, x_decode.type_as(x_masked))\n", + "\n", + " # decoder forward pass\n", + " x_decoded = self.decoder.decode(x_masked)\n", + "\n", + " # predict pixel values for masked tokens\n", + " x_pred = utils.get_at_index(x_decoded, idx_mask)\n", + " x_pred = self.decoder.predict(x_pred)\n", + " return x_pred\n", + "\n", + " def training_step(self, batch, batch_idx):\n", + " views = batch[0]\n", + " images = views[0] # views contains only a single view\n", + " batch_size = images.shape[0]\n", + " idx_keep, idx_mask = utils.random_token_mask(\n", + " size=(batch_size, self.sequence_length),\n", + " mask_ratio=self.mask_ratio,\n", + " device=images.device,\n", + " )\n", + " x_encoded = self.forward_encoder(images=images, idx_keep=idx_keep)\n", + " x_pred = self.forward_decoder(\n", + " x_encoded=x_encoded, idx_keep=idx_keep, idx_mask=idx_mask\n", + " )\n", + "\n", + " # get image patches for masked tokens\n", + " patches = utils.patchify(images, self.patch_size)\n", + " # must adjust idx_mask for missing class token\n", + " target = utils.get_at_index(patches, idx_mask - 1)\n", + "\n", + " loss = self.criterion(x_pred, target)\n", + " return loss\n", + "\n", + " def configure_optimizers(self):\n", + " optim = torch.optim.AdamW(self.parameters(), lr=1.5e-4)\n", + " return optim" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "model = MAE()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "transform = 
MAETransform()\n", + "# we ignore object detection annotations by setting target_transform to return 0\n", + "dataset = torchvision.datasets.VOCDetection(\n", + " \"datasets/pascal_voc\",\n", + " download=True,\n", + " transform=transform,\n", + " target_transform=lambda t: 0,\n", + ")\n", + "# or create a dataset from a folder containing images or videos:\n", + "# dataset = LightlyDataset(\"path/to/folder\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "dataloader = torch.utils.data.DataLoader(\n", + " dataset,\n", + " batch_size=256,\n", + " shuffle=True,\n", + " drop_last=True,\n", + " num_workers=8,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "# Train with DDP on multiple gpus. Distributed sampling is also enabled with\n", + "# replace_sampler_ddp=True.\n", + "trainer = pl.Trainer(\n", + " max_epochs=10,\n", + " devices=\"auto\",\n", + " accelerator=\"gpu\",\n", + " strategy=\"ddp\",\n", + " use_distributed_sampler=True, # or replace_sampler_ddp=True for PyTorch Lightning <2.0\n", + ")\n", + "trainer.fit(model=model, train_dataloaders=dataloader)" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": "-all" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/notebooks/pytorch_lightning_distributed/mmcr.ipynb b/examples/notebooks/pytorch_lightning_distributed/mmcr.ipynb new file mode 100644 index 000000000..0bff148da --- /dev/null +++ b/examples/notebooks/pytorch_lightning_distributed/mmcr.ipynb @@ -0,0 +1,193 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "This example requires the following dependencies to be installed:\n", + "pip install lightly" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install lightly" + ] + }, + { + "cell_type": "markdown", + "id": "2", + "metadata": {}, + "source": [ + "Note: The model and training settings do not follow the reference settings\n", + "from the paper. The settings are chosen such that the example can easily be\n", + "run on a small dataset with a single GPU." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "import copy" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "import pytorch_lightning as pl\n", + "import torch\n", + "import torchvision\n", + "from torch import nn" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "from lightly.loss import MMCRLoss\n", + "from lightly.models.modules import MMCRProjectionHead\n", + "from lightly.models.utils import deactivate_requires_grad, update_momentum\n", + "from lightly.transforms.mmcr_transform import MMCRTransform\n", + "from lightly.utils.scheduler import cosine_schedule" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "class MMCR(pl.LightningModule):\n", + " def __init__(self):\n", + " super().__init__()\n", + " resnet = torchvision.models.resnet18()\n", + " self.backbone = nn.Sequential(*list(resnet.children())[:-1])\n", + " self.projection_head = MMCRProjectionHead(512, 512, 128)\n", + "\n", + " self.backbone_momentum = copy.deepcopy(self.backbone)\n", + " self.projection_head_momentum = copy.deepcopy(self.projection_head)\n", + "\n", + " deactivate_requires_grad(self.backbone_momentum)\n", + " deactivate_requires_grad(self.projection_head_momentum)\n", + "\n", + " self.criterion = MMCRLoss()\n", + "\n", + " def forward(self, x):\n", + " y = self.backbone(x).flatten(start_dim=1)\n", + " z = self.projection_head(y)\n", + " return z\n", + "\n", + " def forward_momentum(self, x):\n", + " y = self.backbone_momentum(x).flatten(start_dim=1)\n", + " z = self.projection_head_momentum(y)\n", + " z = z.detach()\n", + " return z\n", + "\n", + " def training_step(self, batch, batch_idx):\n", + " momentum = cosine_schedule(self.current_epoch, 10, 0.996, 1)\n", + " update_momentum(self.backbone, self.backbone_momentum, m=momentum)\n", + " update_momentum(self.projection_head, self.projection_head_momentum, m=momentum)\n", + " z_o = [model(x) for x in batch[0]]\n", + " z_m = [model.forward_momentum(x) for x in batch[0]]\n", + "\n", + " # Switch dimensions to (batch_size, k, embedding_size)\n", + " z_o = torch.stack(z_o, dim=1)\n", + " z_m = torch.stack(z_m, dim=1)\n", + "\n", + " loss = self.criterion(z_o, z_m)\n", + " return loss\n", + "\n", + " def configure_optimizers(self):\n", + " return torch.optim.SGD(self.parameters(), lr=0.06)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "model = MMCR()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "# We disable resizing and gaussian blur for cifar10.\n", + "transform = MMCRTransform(k=8, input_size=32, gaussian_blur=0.0)\n", + "dataset = torchvision.datasets.CIFAR10(\n", + " \"datasets/cifar10\", download=True, transform=transform\n", + ")\n", + "# or create a dataset from a folder containing images or videos:\n", + "# dataset = LightlyDataset(\"path/to/folder\", transform=transform)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "dataloader = torch.utils.data.DataLoader(\n", + " dataset,\n", + " batch_size=256,\n", + " shuffle=True,\n", + " drop_last=True,\n", + " num_workers=8,\n", + ")" + ] + }, + { + 
"cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [], + "source": [ + "# Train with DDP and use Synchronized Batch Norm for a more accurate batch norm\n", + "# calculation. Distributed sampling is also enabled with replace_sampler_ddp=True.\n", + "if __name__ == \"__main__\":\n", + " trainer = pl.Trainer(\n", + " max_epochs=10,\n", + " devices=\"auto\",\n", + " accelerator=\"gpu\",\n", + " strategy=\"ddp\",\n", + " sync_batchnorm=True,\n", + " use_distributed_sampler=True, # or replace_sampler_ddp=True for PyTorch Lightning <2.0\n", + " )\n", + " trainer.fit(model=model, train_dataloaders=dataloader)" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": "-all" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/notebooks/pytorch_lightning_distributed/moco.ipynb b/examples/notebooks/pytorch_lightning_distributed/moco.ipynb new file mode 100644 index 000000000..b669b9790 --- /dev/null +++ b/examples/notebooks/pytorch_lightning_distributed/moco.ipynb @@ -0,0 +1,187 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "This example requires the following dependencies to be installed:\n", + "pip install lightly" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install lightly" + ] + }, + { + "cell_type": "markdown", + "id": "2", + "metadata": {}, + "source": [ + "Note: The model and training settings do not follow the reference settings\n", + "from the paper. The settings are chosen such that the example can easily be\n", + "run on a small dataset with a single GPU." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "import copy" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "import pytorch_lightning as pl\n", + "import torch\n", + "import torchvision\n", + "from torch import nn" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "from lightly.loss import NTXentLoss\n", + "from lightly.models.modules import MoCoProjectionHead\n", + "from lightly.models.utils import deactivate_requires_grad, update_momentum\n", + "from lightly.transforms.moco_transform import MoCoV2Transform\n", + "from lightly.utils.scheduler import cosine_schedule" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "class MoCo(pl.LightningModule):\n", + " def __init__(self):\n", + " super().__init__()\n", + " resnet = torchvision.models.resnet18()\n", + " self.backbone = nn.Sequential(*list(resnet.children())[:-1])\n", + " self.projection_head = MoCoProjectionHead(512, 512, 128)\n", + "\n", + " self.backbone_momentum = copy.deepcopy(self.backbone)\n", + " self.projection_head_momentum = copy.deepcopy(self.projection_head)\n", + "\n", + " deactivate_requires_grad(self.backbone_momentum)\n", + " deactivate_requires_grad(self.projection_head_momentum)\n", + "\n", + " self.criterion = NTXentLoss(memory_bank_size=(4096, 128))\n", + "\n", + " def forward(self, x):\n", + " query = self.backbone(x).flatten(start_dim=1)\n", + " query = self.projection_head(query)\n", + " return query\n", + "\n", + " def forward_momentum(self, x):\n", + " key = 
self.backbone_momentum(x).flatten(start_dim=1)\n", + " key = self.projection_head_momentum(key).detach()\n", + " return key\n", + "\n", + " def training_step(self, batch, batch_idx):\n", + " momentum = cosine_schedule(self.current_epoch, 10, 0.996, 1)\n", + " update_momentum(self.backbone, self.backbone_momentum, m=momentum)\n", + " update_momentum(self.projection_head, self.projection_head_momentum, m=momentum)\n", + " x_query, x_key = batch[0]\n", + " query = self.forward(x_query)\n", + " key = self.forward_momentum(x_key)\n", + " loss = self.criterion(query, key)\n", + " return loss\n", + "\n", + " def configure_optimizers(self):\n", + " optim = torch.optim.SGD(self.parameters(), lr=0.06)\n", + " return optim" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "model = MoCo()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "transform = MoCoV2Transform(input_size=32)\n", + "dataset = torchvision.datasets.CIFAR10(\n", + " \"datasets/cifar10\", download=True, transform=transform\n", + ")\n", + "# or create a dataset from a folder containing images or videos:\n", + "# dataset = LightlyDataset(\"path/to/folder\", transform=transform)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "dataloader = torch.utils.data.DataLoader(\n", + " dataset,\n", + " batch_size=256,\n", + " shuffle=True,\n", + " drop_last=True,\n", + " num_workers=8,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [], + "source": [ + "# Train with DDP and use Synchronized Batch Norm for a more accurate batch norm\n", + "# calculation. Distributed sampling is also enabled with replace_sampler_ddp=True.\n", + "trainer = pl.Trainer(\n", + " max_epochs=10,\n", + " devices=\"auto\",\n", + " accelerator=\"gpu\",\n", + " strategy=\"ddp\",\n", + " sync_batchnorm=True,\n", + " use_distributed_sampler=True, # or replace_sampler_ddp=True for PyTorch Lightning <2.0\n", + ")\n", + "trainer.fit(model=model, train_dataloaders=dataloader)" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": "-all" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/notebooks/pytorch_lightning_distributed/msn.ipynb b/examples/notebooks/pytorch_lightning_distributed/msn.ipynb new file mode 100644 index 000000000..02128f986 --- /dev/null +++ b/examples/notebooks/pytorch_lightning_distributed/msn.ipynb @@ -0,0 +1,215 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "This example requires the following dependencies to be installed:\n", + "pip install lightly" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install lightly" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2", + "metadata": {}, + "outputs": [], + "source": [ + "# Note: The model and training settings do not follow the reference settings\n", + "# from the paper. 
The settings are chosen such that the example can easily be\n", + "# run on a small dataset with a single GPU.\n", + "import copy" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "import pytorch_lightning as pl\n", + "import torch\n", + "import torchvision\n", + "from torch import nn" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "from lightly.loss import MSNLoss\n", + "from lightly.models import utils\n", + "from lightly.models.modules import MaskedVisionTransformerTorchvision\n", + "from lightly.models.modules.heads import MSNProjectionHead\n", + "from lightly.transforms.msn_transform import MSNTransform" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "class MSN(pl.LightningModule):\n", + " def __init__(self):\n", + " super().__init__()\n", + "\n", + " # ViT small configuration (ViT-S/16)\n", + " self.mask_ratio = 0.15\n", + " # ViT small configuration (ViT-S/16)\n", + " vit = torchvision.models.VisionTransformer(\n", + " image_size=224,\n", + " patch_size=16,\n", + " num_layers=12,\n", + " num_heads=6,\n", + " hidden_dim=384,\n", + " mlp_dim=384 * 4,\n", + " )\n", + " self.backbone = MaskedVisionTransformerTorchvision(vit=vit)\n", + " # or use a torchvision ViT backbone:\n", + " # vit = torchvision.models.vit_b_32(pretrained=False)\n", + " # self.backbone = MAEBackbone.from_vit(vit)\n", + " self.projection_head = MSNProjectionHead(384)\n", + "\n", + " self.anchor_backbone = copy.deepcopy(self.backbone)\n", + " self.anchor_projection_head = copy.deepcopy(self.projection_head)\n", + "\n", + " utils.deactivate_requires_grad(self.backbone)\n", + " utils.deactivate_requires_grad(self.projection_head)\n", + "\n", + " self.prototypes = nn.Linear(256, 1024, bias=False).weight\n", + "\n", + " # set gather_distributed to True for distributed training\n", + " self.criterion = MSNLoss(gather_distributed=True)\n", + "\n", + " def training_step(self, batch, batch_idx):\n", + " utils.update_momentum(self.anchor_backbone, self.backbone, 0.996)\n", + " utils.update_momentum(self.anchor_projection_head, self.projection_head, 0.996)\n", + "\n", + " views = batch[0]\n", + " views = [view.to(self.device, non_blocking=True) for view in views]\n", + " targets = views[0]\n", + " anchors = views[1]\n", + " anchors_focal = torch.concat(views[2:], dim=0)\n", + "\n", + " targets_out = self.backbone(images=targets)\n", + " targets_out = self.projection_head(targets_out)\n", + " anchors_out = self.encode_masked(anchors)\n", + " anchors_focal_out = self.encode_masked(anchors_focal)\n", + " anchors_out = torch.cat([anchors_out, anchors_focal_out], dim=0)\n", + "\n", + " loss = self.criterion(anchors_out, targets_out, self.prototypes.data)\n", + " return loss\n", + "\n", + " def encode_masked(self, anchors):\n", + " batch_size, _, _, width = anchors.shape\n", + " seq_length = (width // self.anchor_backbone.vit.patch_size) ** 2\n", + " idx_keep, _ = utils.random_token_mask(\n", + " size=(batch_size, seq_length),\n", + " mask_ratio=self.mask_ratio,\n", + " device=self.device,\n", + " )\n", + " out = self.anchor_backbone(images=anchors, idx_keep=idx_keep)\n", + " return self.anchor_projection_head(out)\n", + "\n", + " def configure_optimizers(self):\n", + " params = [\n", + " *list(self.anchor_backbone.parameters()),\n", + " *list(self.anchor_projection_head.parameters()),\n", + " 
self.prototypes,\n", + " ]\n", + " optim = torch.optim.AdamW(params, lr=1.5e-4)\n", + " return optim" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "model = MSN()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "transform = MSNTransform()\n", + "# we ignore object detection annotations by setting target_transform to return 0\n", + "dataset = torchvision.datasets.VOCDetection(\n", + " \"datasets/pascal_voc\",\n", + " download=True,\n", + " transform=transform,\n", + " target_transform=lambda t: 0,\n", + ")\n", + "# or create a dataset from a folder containing images or videos:\n", + "# dataset = LightlyDataset(\"path/to/folder\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "dataloader = torch.utils.data.DataLoader(\n", + " dataset,\n", + " batch_size=64,\n", + " shuffle=True,\n", + " drop_last=True,\n", + " num_workers=8,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "# Train with DDP on multiple gpus. Distributed sampling is also enabled with\n", + "# replace_sampler_ddp=True.\n", + "trainer = pl.Trainer(\n", + " max_epochs=10,\n", + " devices=\"auto\",\n", + " accelerator=\"gpu\",\n", + " strategy=\"ddp\",\n", + " use_distributed_sampler=True, # or replace_sampler_ddp=True for PyTorch Lightning <2.0\n", + ")\n", + "trainer.fit(model=model, train_dataloaders=dataloader)" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": "-all" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/notebooks/pytorch_lightning_distributed/nnclr.ipynb b/examples/notebooks/pytorch_lightning_distributed/nnclr.ipynb new file mode 100644 index 000000000..8dddc1399 --- /dev/null +++ b/examples/notebooks/pytorch_lightning_distributed/nnclr.ipynb @@ -0,0 +1,171 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "This example requires the following dependencies to be installed:\n", + "pip install lightly" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install lightly" + ] + }, + { + "cell_type": "markdown", + "id": "2", + "metadata": {}, + "source": [ + "Note: The model and training settings do not follow the reference settings\n", + "from the paper. The settings are chosen such that the example can easily be\n", + "run on a small dataset with a single GPU." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "import pytorch_lightning as pl\n", + "import torch\n", + "import torchvision\n", + "from torch import nn" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "from lightly.loss import NTXentLoss\n", + "from lightly.models.modules import (\n", + " NNCLRPredictionHead,\n", + " NNCLRProjectionHead,\n", + " NNMemoryBankModule,\n", + ")\n", + "from lightly.transforms.simclr_transform import SimCLRTransform" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "class NNCLR(pl.LightningModule):\n", + " def __init__(self):\n", + " super().__init__()\n", + " resnet = torchvision.models.resnet18()\n", + " self.backbone = nn.Sequential(*list(resnet.children())[:-1])\n", + " self.projection_head = NNCLRProjectionHead(512, 512, 128)\n", + " self.prediction_head = NNCLRPredictionHead(128, 512, 128)\n", + " self.memory_bank = NNMemoryBankModule(size=(4096, 128))\n", + "\n", + " self.criterion = NTXentLoss()\n", + "\n", + " def forward(self, x):\n", + " y = self.backbone(x).flatten(start_dim=1)\n", + " z = self.projection_head(y)\n", + " p = self.prediction_head(z)\n", + " z = z.detach()\n", + " return z, p\n", + "\n", + " def training_step(self, batch, batch_idx):\n", + " (x0, x1) = batch[0]\n", + " z0, p0 = self.forward(x0)\n", + " z1, p1 = self.forward(x1)\n", + " z0 = self.memory_bank(z0, update=False)\n", + " z1 = self.memory_bank(z1, update=True)\n", + " loss = 0.5 * (self.criterion(z0, p1) + self.criterion(z1, p0))\n", + " return loss\n", + "\n", + " def configure_optimizers(self):\n", + " optim = torch.optim.SGD(self.parameters(), lr=0.06)\n", + " return optim" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "model = NNCLR()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "transform = SimCLRTransform(input_size=32)\n", + "dataset = torchvision.datasets.CIFAR10(\n", + " \"datasets/cifar10\", download=True, transform=transform\n", + ")\n", + "# or create a dataset from a folder containing images or videos:\n", + "# dataset = LightlyDataset(\"path/to/folder\", transform=transform)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "dataloader = torch.utils.data.DataLoader(\n", + " dataset,\n", + " batch_size=256,\n", + " shuffle=True,\n", + " drop_last=True,\n", + " num_workers=8,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "# Train with DDP and use Synchronized Batch Norm for a more accurate batch norm\n", + "# calculation. 
Distributed sampling is also enabled with replace_sampler_ddp=True.\n", + "trainer = pl.Trainer(\n", + " max_epochs=10,\n", + " devices=\"auto\",\n", + " accelerator=\"gpu\",\n", + " strategy=\"ddp\",\n", + " sync_batchnorm=True,\n", + " use_distributed_sampler=True, # or replace_sampler_ddp=True for PyTorch Lightning <2.0\n", + ")\n", + "trainer.fit(model=model, train_dataloaders=dataloader)" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": "-all" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/notebooks/pytorch_lightning_distributed/pmsn.ipynb b/examples/notebooks/pytorch_lightning_distributed/pmsn.ipynb new file mode 100644 index 000000000..da0975baa --- /dev/null +++ b/examples/notebooks/pytorch_lightning_distributed/pmsn.ipynb @@ -0,0 +1,224 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "This example requires the following dependencies to be installed:\n", + "pip install lightly" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install lightly" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2", + "metadata": {}, + "outputs": [], + "source": [ + "# Note: The model and training settings do not follow the reference settings\n", + "# from the paper. The settings are chosen such that the example can easily be\n", + "# run on a small dataset with a single GPU.\n", + "import copy" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "import pytorch_lightning as pl\n", + "import torch\n", + "import torchvision\n", + "from torch import nn" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "from lightly.loss import PMSNLoss\n", + "from lightly.models import utils\n", + "from lightly.models.modules import MaskedVisionTransformerTorchvision\n", + "from lightly.models.modules.heads import MSNProjectionHead\n", + "from lightly.transforms import MSNTransform" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "class PMSN(pl.LightningModule):\n", + " def __init__(self):\n", + " super().__init__()\n", + "\n", + " # ViT small configuration (ViT-S/16)\n", + " self.mask_ratio = 0.15\n", + " vit = torchvision.models.VisionTransformer(\n", + " image_size=224,\n", + " patch_size=16,\n", + " num_layers=12,\n", + " num_heads=6,\n", + " hidden_dim=384,\n", + " mlp_dim=384 * 4,\n", + " )\n", + " self.backbone = MaskedVisionTransformerTorchvision(vit=vit)\n", + " # or use a torchvision ViT backbone:\n", + " # vit = torchvision.models.vit_b_32(pretrained=False)\n", + " # self.backbone = MAEBackbone.from_vit(vit)\n", + " self.projection_head = MSNProjectionHead(384)\n", + "\n", + " self.anchor_backbone = copy.deepcopy(self.backbone)\n", + " self.anchor_projection_head = copy.deepcopy(self.projection_head)\n", + "\n", + " utils.deactivate_requires_grad(self.backbone)\n", + " utils.deactivate_requires_grad(self.projection_head)\n", + "\n", + " self.prototypes = nn.Linear(256, 1024, bias=False).weight\n", + "\n", + " # set gather_distributed to True for distributed training\n", + " self.criterion = PMSNLoss(gather_distributed=True)\n", + "\n", + " def training_step(self, batch, batch_idx):\n", + " 
utils.update_momentum(self.anchor_backbone, self.backbone, 0.996)\n", + " utils.update_momentum(self.anchor_projection_head, self.projection_head, 0.996)\n", + "\n", + " views = batch[0]\n", + " views = [view.to(self.device, non_blocking=True) for view in views]\n", + " targets = views[0]\n", + " anchors = views[1]\n", + " anchors_focal = torch.concat(views[2:], dim=0)\n", + "\n", + " targets_out = self.backbone(images=targets)\n", + " targets_out = self.projection_head(targets_out)\n", + " anchors_out = self.encode_masked(anchors)\n", + " anchors_focal_out = self.encode_masked(anchors_focal)\n", + " anchors_out = torch.cat([anchors_out, anchors_focal_out], dim=0)\n", + "\n", + " loss = self.criterion(anchors_out, targets_out, self.prototypes.data)\n", + " return loss\n", + "\n", + " def encode_masked(self, anchors):\n", + " batch_size, _, _, width = anchors.shape\n", + " seq_length = (width // self.anchor_backbone.vit.patch_size) ** 2\n", + " idx_keep, _ = utils.random_token_mask(\n", + " size=(batch_size, seq_length),\n", + " mask_ratio=self.mask_ratio,\n", + " device=self.device,\n", + " )\n", + " out = self.anchor_backbone(images=anchors, idx_keep=idx_keep)\n", + " return self.anchor_projection_head(out)\n", + "\n", + " def configure_optimizers(self):\n", + " params = [\n", + " *list(self.anchor_backbone.parameters()),\n", + " *list(self.anchor_projection_head.parameters()),\n", + " self.prototypes,\n", + " ]\n", + " optim = torch.optim.AdamW(params, lr=1.5e-4)\n", + " return optim" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "model = PMSN()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "transform = MSNTransform()\n", + "# we ignore object detection annotations by setting target_transform to return 0\n", + "dataset = torchvision.datasets.VOCDetection(\n", + " \"datasets/pascal_voc\",\n", + " download=True,\n", + " transform=transform,\n", + " target_transform=lambda t: 0,\n", + ")\n", + "# or create a dataset from a folder containing images or videos:\n", + "# dataset = LightlyDataset(\"path/to/folder\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "dataloader = torch.utils.data.DataLoader(\n", + " dataset,\n", + " batch_size=64,\n", + " shuffle=True,\n", + " drop_last=True,\n", + " num_workers=8,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "gpus = torch.cuda.device_count()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [], + "source": [ + "# Train with DDP on multiple gpus. 
Distributed sampling is also enabled with\n", + "# replace_sampler_ddp=True.\n", + "trainer = pl.Trainer(\n", + " max_epochs=10,\n", + " devices=\"auto\",\n", + " accelerator=\"gpu\",\n", + " strategy=\"ddp\",\n", + " use_distributed_sampler=True, # or replace_sampler_ddp=True for PyTorch Lightning <2.0\n", + ")\n", + "trainer.fit(model=model, train_dataloaders=dataloader)" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": "-all" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/notebooks/pytorch_lightning_distributed/simclr.ipynb b/examples/notebooks/pytorch_lightning_distributed/simclr.ipynb new file mode 100644 index 000000000..a5a706258 --- /dev/null +++ b/examples/notebooks/pytorch_lightning_distributed/simclr.ipynb @@ -0,0 +1,153 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "This example requires the following dependencies to be installed:\n", + "pip install lightly" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install lightly" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2", + "metadata": {}, + "outputs": [], + "source": [ + "import pytorch_lightning as pl\n", + "import torch\n", + "import torchvision\n", + "from torch import nn" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "from lightly.loss import NTXentLoss\n", + "from lightly.models.modules import SimCLRProjectionHead\n", + "from lightly.transforms.simclr_transform import SimCLRTransform" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "class SimCLR(pl.LightningModule):\n", + " def __init__(self):\n", + " super().__init__()\n", + " resnet = torchvision.models.resnet18()\n", + " self.backbone = nn.Sequential(*list(resnet.children())[:-1])\n", + " self.projection_head = SimCLRProjectionHead(512, 2048, 2048)\n", + "\n", + " # enable gather_distributed to gather features from all gpus\n", + " # before calculating the loss\n", + " self.criterion = NTXentLoss(gather_distributed=True)\n", + "\n", + " def forward(self, x):\n", + " x = self.backbone(x).flatten(start_dim=1)\n", + " z = self.projection_head(x)\n", + " return z\n", + "\n", + " def training_step(self, batch, batch_index):\n", + " (x0, x1) = batch[0]\n", + " z0 = self.forward(x0)\n", + " z1 = self.forward(x1)\n", + " loss = self.criterion(z0, z1)\n", + " return loss\n", + "\n", + " def configure_optimizers(self):\n", + " optim = torch.optim.SGD(self.parameters(), lr=0.06)\n", + " return optim" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "model = SimCLR()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "transform = SimCLRTransform(input_size=32)\n", + "dataset = torchvision.datasets.CIFAR10(\n", + " \"datasets/cifar10\", download=True, transform=transform\n", + ")\n", + "# or create a dataset from a folder containing images or videos:\n", + "# dataset = LightlyDataset(\"path/to/folder\", transform=transform)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "dataloader = torch.utils.data.DataLoader(\n", + " 
dataset,\n", + " batch_size=256,\n", + " shuffle=True,\n", + " drop_last=True,\n", + " num_workers=8,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "# Train with DDP and use Synchronized Batch Norm for a more accurate batch norm\n", + "# calculation. Distributed sampling is also enabled with replace_sampler_ddp=True.\n", + "trainer = pl.Trainer(\n", + " max_epochs=10,\n", + " devices=\"auto\",\n", + " accelerator=\"gpu\",\n", + " strategy=\"ddp\",\n", + " sync_batchnorm=True,\n", + " use_distributed_sampler=True, # or replace_sampler_ddp=True for PyTorch Lightning <2.0\n", + ")\n", + "trainer.fit(model=model, train_dataloaders=dataloader)" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": "-all" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/notebooks/pytorch_lightning_distributed/simmim.ipynb b/examples/notebooks/pytorch_lightning_distributed/simmim.ipynb new file mode 100644 index 000000000..7376f62c9 --- /dev/null +++ b/examples/notebooks/pytorch_lightning_distributed/simmim.ipynb @@ -0,0 +1,195 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "This example requires the following dependencies to be installed:\n", + "pip install lightly" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install lightly" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2", + "metadata": {}, + "outputs": [], + "source": [ + "import pytorch_lightning as pl\n", + "import torch\n", + "import torchvision\n", + "from torch import nn" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "from lightly.models import utils\n", + "from lightly.models.modules import MaskedVisionTransformerTorchvision\n", + "from lightly.transforms.mae_transform import MAETransform # Same transform as MAE" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "class SimMIM(pl.LightningModule):\n", + " def __init__(self):\n", + " super().__init__()\n", + "\n", + " vit = torchvision.models.vit_b_32(pretrained=False)\n", + " self.mask_ratio = 0.75\n", + " self.patch_size = vit.patch_size\n", + " self.sequence_length = vit.seq_length\n", + " decoder_dim = vit.hidden_dim\n", + "\n", + " self.backbone = MaskedVisionTransformerTorchvision(vit=vit)\n", + "\n", + " # the decoder is a simple linear layer\n", + " self.decoder = nn.Linear(decoder_dim, vit.patch_size**2 * 3)\n", + "\n", + " # L1 loss as paper suggestion\n", + " self.criterion = nn.L1Loss()\n", + "\n", + " def forward_encoder(self, images, batch_size, idx_mask):\n", + " # pass all the tokens to the encoder, both masked and non masked ones\n", + " return self.backbone.encode(images=images, idx_mask=idx_mask)\n", + "\n", + " def forward_decoder(self, x_encoded):\n", + " return self.decoder(x_encoded)\n", + "\n", + " def training_step(self, batch, batch_idx):\n", + " views = batch[0]\n", + " images = views[0] # views contains only a single view\n", + " batch_size = images.shape[0]\n", + " idx_keep, idx_mask = utils.random_token_mask(\n", + " size=(batch_size, self.sequence_length),\n", + " mask_ratio=self.mask_ratio,\n", + " device=images.device,\n", + " )\n", + "\n", + " # 
Encoding...\n", + " x_encoded = self.forward_encoder(images, batch_size, idx_mask)\n", + " x_encoded_masked = utils.get_at_index(x_encoded, idx_mask)\n", + "\n", + " # Decoding...\n", + " x_out = self.forward_decoder(x_encoded_masked)\n", + "\n", + " # get image patches for masked tokens\n", + " patches = utils.patchify(images, self.patch_size)\n", + "\n", + " # must adjust idx_mask for missing class token\n", + " target = utils.get_at_index(patches, idx_mask - 1)\n", + "\n", + " loss = self.criterion(x_out, target)\n", + " return loss\n", + "\n", + " def configure_optimizers(self):\n", + " optim = torch.optim.AdamW(self.parameters(), lr=1.5e-4)\n", + " return optim" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "model = SimMIM()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "model.to(device)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "transform = MAETransform()\n", + "# we ignore object detection annotations by setting target_transform to return 0\n", + "dataset = torchvision.datasets.VOCDetection(\n", + " \"datasets/pascal_voc\",\n", + " download=True,\n", + " transform=transform,\n", + " target_transform=lambda t: 0,\n", + ")\n", + "# or create a dataset from a folder containing images or videos:\n", + "# dataset = LightlyDataset(\"path/to/folder\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "dataloader = torch.utils.data.DataLoader(\n", + " dataset,\n", + " batch_size=8,\n", + " shuffle=True,\n", + " drop_last=True,\n", + " num_workers=8,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "# Train with DDP on multiple gpus. Distributed sampling is also enabled with\n", + "# replace_sampler_ddp=True.\n", + "trainer = pl.Trainer(\n", + " max_epochs=10,\n", + " devices=\"auto\",\n", + " accelerator=\"gpu\",\n", + " strategy=\"ddp\",\n", + " use_distributed_sampler=True, # or replace_sampler_ddp=True for PyTorch Lightning <2.0\n", + ")\n", + "trainer.fit(model=model, train_dataloaders=dataloader)" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": "-all" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/notebooks/pytorch_lightning_distributed/simsiam.ipynb b/examples/notebooks/pytorch_lightning_distributed/simsiam.ipynb new file mode 100644 index 000000000..9110a3d46 --- /dev/null +++ b/examples/notebooks/pytorch_lightning_distributed/simsiam.ipynb @@ -0,0 +1,165 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "This example requires the following dependencies to be installed:\n", + "pip install lightly" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install lightly" + ] + }, + { + "cell_type": "markdown", + "id": "2", + "metadata": {}, + "source": [ + "Note: The model and training settings do not follow the reference settings\n", + "from the paper. The settings are chosen such that the example can easily be\n", + "run on a small dataset with a single GPU." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "import pytorch_lightning as pl\n", + "import torch\n", + "import torchvision\n", + "from torch import nn" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "from lightly.loss import NegativeCosineSimilarity\n", + "from lightly.models.modules import SimSiamPredictionHead, SimSiamProjectionHead\n", + "from lightly.transforms import SimSiamTransform" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "class SimSiam(pl.LightningModule):\n", + " def __init__(self):\n", + " super().__init__()\n", + " resnet = torchvision.models.resnet18()\n", + " self.backbone = nn.Sequential(*list(resnet.children())[:-1])\n", + " self.projection_head = SimSiamProjectionHead(512, 512, 128)\n", + " self.prediction_head = SimSiamPredictionHead(128, 64, 128)\n", + " self.criterion = NegativeCosineSimilarity()\n", + "\n", + " def forward(self, x):\n", + " f = self.backbone(x).flatten(start_dim=1)\n", + " z = self.projection_head(f)\n", + " p = self.prediction_head(z)\n", + " z = z.detach()\n", + " return z, p\n", + "\n", + " def training_step(self, batch, batch_idx):\n", + " (x0, x1) = batch[0]\n", + " z0, p0 = self.forward(x0)\n", + " z1, p1 = self.forward(x1)\n", + " loss = 0.5 * (self.criterion(z0, p1) + self.criterion(z1, p0))\n", + " return loss\n", + "\n", + " def configure_optimizers(self):\n", + " optim = torch.optim.SGD(self.parameters(), lr=0.06)\n", + " return optim" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "resnet = torchvision.models.resnet18()\n", + "backbone = nn.Sequential(*list(resnet.children())[:-1])\n", + "model = SimSiam()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "transform = SimSiamTransform(input_size=32)\n", + "dataset = torchvision.datasets.CIFAR10(\n", + " \"datasets/cifar10\", download=True, transform=transform\n", + ")\n", + "# or create a dataset from a folder containing images or videos:\n", + "# dataset = LightlyDataset(\"path/to/folder\", transform=transform)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "dataloader = torch.utils.data.DataLoader(\n", + " dataset,\n", + " batch_size=256,\n", + " shuffle=True,\n", + " drop_last=True,\n", + " num_workers=8,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "# Train with DDP and use Synchronized Batch Norm for a more accurate batch norm\n", + "# calculation. 
Distributed sampling is also enabled with replace_sampler_ddp=True.\n", + "trainer = pl.Trainer(\n", + " max_epochs=10,\n", + " devices=\"auto\",\n", + " accelerator=\"gpu\",\n", + " strategy=\"ddp\",\n", + " sync_batchnorm=True,\n", + " use_distributed_sampler=True, # or replace_sampler_ddp=True for PyTorch Lightning <2.0\n", + ")\n", + "trainer.fit(model=model, train_dataloaders=dataloader)" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": "-all" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/notebooks/pytorch_lightning_distributed/swav.ipynb b/examples/notebooks/pytorch_lightning_distributed/swav.ipynb new file mode 100644 index 000000000..22bee7268 --- /dev/null +++ b/examples/notebooks/pytorch_lightning_distributed/swav.ipynb @@ -0,0 +1,172 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "This example requires the following dependencies to be installed:\n", + "pip install lightly" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install lightly" + ] + }, + { + "cell_type": "markdown", + "id": "2", + "metadata": {}, + "source": [ + "Note: The model and training settings do not follow the reference settings\n", + "from the paper. The settings are chosen such that the example can easily be\n", + "run on a small dataset with a single GPU." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "import pytorch_lightning as pl\n", + "import torch\n", + "import torchvision\n", + "from torch import nn" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "from lightly.loss import SwaVLoss\n", + "from lightly.models.modules import SwaVProjectionHead, SwaVPrototypes\n", + "from lightly.transforms.swav_transform import SwaVTransform" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "class SwaV(pl.LightningModule):\n", + " def __init__(self):\n", + " super().__init__()\n", + " resnet = torchvision.models.resnet18()\n", + " self.backbone = nn.Sequential(*list(resnet.children())[:-1])\n", + " self.projection_head = SwaVProjectionHead(512, 512, 128)\n", + " self.prototypes = SwaVPrototypes(128, n_prototypes=512)\n", + "\n", + " # enable sinkhorn_gather_distributed to gather features from all gpus\n", + " # while running the sinkhorn algorithm in the loss calculation\n", + " self.criterion = SwaVLoss(sinkhorn_gather_distributed=True)\n", + "\n", + " def forward(self, x):\n", + " x = self.backbone(x).flatten(start_dim=1)\n", + " x = self.projection_head(x)\n", + " x = nn.functional.normalize(x, dim=1, p=2)\n", + " p = self.prototypes(x)\n", + " return p\n", + "\n", + " def training_step(self, batch, batch_idx):\n", + " self.prototypes.normalize()\n", + " views = batch[0]\n", + " multi_crop_features = [self.forward(view.to(self.device)) for view in views]\n", + " high_resolution = multi_crop_features[:2]\n", + " low_resolution = multi_crop_features[2:]\n", + " loss = self.criterion(high_resolution, low_resolution)\n", + " return loss\n", + "\n", + " def configure_optimizers(self):\n", + " optim = torch.optim.Adam(self.parameters(), lr=0.001)\n", + " return optim" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + 
"metadata": {}, + "outputs": [], + "source": [ + "model = SwaV()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "transform = SwaVTransform()\n", + "# we ignore object detection annotations by setting target_transform to return 0\n", + "dataset = torchvision.datasets.VOCDetection(\n", + " \"datasets/pascal_voc\",\n", + " download=True,\n", + " transform=transform,\n", + " target_transform=lambda t: 0,\n", + ")\n", + "# or create a dataset from a folder containing images or videos:\n", + "# dataset = LightlyDataset(\"path/to/folder\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "dataloader = torch.utils.data.DataLoader(\n", + " dataset,\n", + " batch_size=128,\n", + " shuffle=True,\n", + " drop_last=True,\n", + " num_workers=8,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "# Train with DDP and use Synchronized Batch Norm for a more accurate batch norm\n", + "# calculation. Distributed sampling is also enabled with replace_sampler_ddp=True.\n", + "trainer = pl.Trainer(\n", + " max_epochs=10,\n", + " devices=\"auto\",\n", + " accelerator=\"gpu\",\n", + " strategy=\"ddp\",\n", + " sync_batchnorm=True,\n", + " use_distributed_sampler=True, # or replace_sampler_ddp=True for PyTorch Lightning <2.0\n", + ")\n", + "trainer.fit(model=model, train_dataloaders=dataloader)" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": "-all" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/notebooks/pytorch_lightning_distributed/tico.ipynb b/examples/notebooks/pytorch_lightning_distributed/tico.ipynb new file mode 100644 index 000000000..100881e3d --- /dev/null +++ b/examples/notebooks/pytorch_lightning_distributed/tico.ipynb @@ -0,0 +1,188 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "This example requires the following dependencies to be installed:\n", + "pip install lightly" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install lightly" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2", + "metadata": {}, + "outputs": [], + "source": [ + "import copy" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "import pytorch_lightning as pl\n", + "import torch\n", + "import torchvision\n", + "from torch import nn" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "from lightly.loss.tico_loss import TiCoLoss\n", + "from lightly.models.modules.heads import TiCoProjectionHead\n", + "from lightly.models.utils import deactivate_requires_grad, update_momentum\n", + "from lightly.transforms.byol_transform import (\n", + " BYOLTransform,\n", + " BYOLView1Transform,\n", + " BYOLView2Transform,\n", + ")\n", + "from lightly.utils.scheduler import cosine_schedule" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "class TiCo(pl.LightningModule):\n", + " def __init__(self):\n", + " super().__init__()\n", + " resnet = torchvision.models.resnet18()\n", + " self.backbone = 
nn.Sequential(*list(resnet.children())[:-1])\n", + " self.projection_head = TiCoProjectionHead(512, 1024, 256)\n", + "\n", + " self.backbone_momentum = copy.deepcopy(self.backbone)\n", + " self.projection_head_momentum = copy.deepcopy(self.projection_head)\n", + "\n", + " deactivate_requires_grad(self.backbone_momentum)\n", + " deactivate_requires_grad(self.projection_head_momentum)\n", + "\n", + " self.criterion = TiCoLoss()\n", + "\n", + " def forward(self, x):\n", + " y = self.backbone(x).flatten(start_dim=1)\n", + " z = self.projection_head(y)\n", + " return z\n", + "\n", + " def forward_momentum(self, x):\n", + " y = self.backbone_momentum(x).flatten(start_dim=1)\n", + " z = self.projection_head_momentum(y)\n", + " z = z.detach()\n", + " return z\n", + "\n", + " def training_step(self, batch, batch_idx):\n", + " (x0, x1) = batch[0]\n", + " momentum = cosine_schedule(self.current_epoch, 10, 0.996, 1)\n", + " update_momentum(self.backbone, self.backbone_momentum, m=momentum)\n", + " update_momentum(self.projection_head, self.projection_head_momentum, m=momentum)\n", + " x0 = x0.to(self.device)\n", + " x1 = x1.to(self.device)\n", + " z0 = self.forward(x0)\n", + " z1 = self.forward_momentum(x1)\n", + " loss = self.criterion(z0, z1)\n", + " return loss\n", + "\n", + " def configure_optimizers(self):\n", + " return torch.optim.SGD(self.parameters(), lr=0.06)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "model = TiCo()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "# TiCo uses BYOL augmentations.\n", + "# We disable resizing and gaussian blur for cifar10.\n", + "transform = BYOLTransform(\n", + " view_1_transform=BYOLView1Transform(input_size=32, gaussian_blur=0.0),\n", + " view_2_transform=BYOLView2Transform(input_size=32, gaussian_blur=0.0),\n", + ")\n", + "dataset = torchvision.datasets.CIFAR10(\n", + " \"datasets/cifar10\", download=True, transform=transform\n", + ")\n", + "# or create a dataset from a folder containing images or videos:\n", + "# dataset = LightlyDataset(\"path/to/folder\", transform=transform)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "dataloader = torch.utils.data.DataLoader(\n", + " dataset,\n", + " batch_size=256,\n", + " shuffle=True,\n", + " drop_last=True,\n", + " num_workers=8,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "# Train with DDP and use Synchronized Batch Norm for a more accurate batch norm\n", + "# calculation. 
Distributed sampling is also enabled with replace_sampler_ddp=True.\n", + "trainer = pl.Trainer(\n", + " max_epochs=10,\n", + " devices=\"auto\",\n", + " accelerator=\"gpu\",\n", + " strategy=\"ddp\",\n", + " sync_batchnorm=True,\n", + " use_distributed_sampler=True, # or replace_sampler_ddp=True for PyTorch Lightning <2.0\n", + ")\n", + "trainer.fit(model=model, train_dataloaders=dataloader)" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": "-all" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/notebooks/pytorch_lightning_distributed/vicreg.ipynb b/examples/notebooks/pytorch_lightning_distributed/vicreg.ipynb new file mode 100644 index 000000000..39e4cff10 --- /dev/null +++ b/examples/notebooks/pytorch_lightning_distributed/vicreg.ipynb @@ -0,0 +1,178 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "This example requires the following dependencies to be installed:\n", + "pip install lightly" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install lightly" + ] + }, + { + "cell_type": "markdown", + "id": "2", + "metadata": {}, + "source": [ + "Note: The model and training settings do not follow the reference settings\n", + "from the paper. The settings are chosen such that the example can easily be\n", + "run on a small dataset with a single GPU." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "import pytorch_lightning as pl\n", + "import torch\n", + "import torchvision\n", + "from torch import nn" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "from lightly.loss import VICRegLoss" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "## The projection head is the same as the Barlow Twins one\n", + "from lightly.models.modules.heads import VICRegProjectionHead\n", + "from lightly.transforms.vicreg_transform import VICRegTransform" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "class VICReg(pl.LightningModule):\n", + " def __init__(self):\n", + " super().__init__()\n", + " resnet = torchvision.models.resnet18()\n", + " self.backbone = nn.Sequential(*list(resnet.children())[:-1])\n", + " self.projection_head = VICRegProjectionHead(\n", + " input_dim=512,\n", + " hidden_dim=2048,\n", + " output_dim=2048,\n", + " num_layers=2,\n", + " )\n", + "\n", + " # enable gather_distributed to gather features from all gpus\n", + " # before calculating the loss\n", + " self.criterion = VICRegLoss(gather_distributed=True)\n", + "\n", + " def forward(self, x):\n", + " x = self.backbone(x).flatten(start_dim=1)\n", + " z = self.projection_head(x)\n", + " return z\n", + "\n", + " def training_step(self, batch, batch_index):\n", + " (x0, x1) = batch[0]\n", + " z0 = self.forward(x0)\n", + " z1 = self.forward(x1)\n", + " loss = self.criterion(z0, z1)\n", + " return loss\n", + "\n", + " def configure_optimizers(self):\n", + " optim = torch.optim.SGD(self.parameters(), lr=0.06)\n", + " return optim" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "model = VICReg()" + ] + }, + { + 
"cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "transform = VICRegTransform(input_size=32)\n", + "dataset = torchvision.datasets.CIFAR10(\n", + " \"datasets/cifar10\", download=True, transform=transform\n", + ")\n", + "# or create a dataset from a folder containing images or videos:\n", + "# dataset = LightlyDataset(\"path/to/folder\", transform=transform)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "dataloader = torch.utils.data.DataLoader(\n", + " dataset,\n", + " batch_size=256,\n", + " shuffle=True,\n", + " drop_last=True,\n", + " num_workers=8,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [], + "source": [ + "# Train with DDP and use Synchronized Batch Norm for a more accurate batch norm\n", + "# calculation. Distributed sampling is also enabled with replace_sampler_ddp=True.\n", + "trainer = pl.Trainer(\n", + " max_epochs=10,\n", + " devices=\"auto\",\n", + " accelerator=\"gpu\",\n", + " strategy=\"ddp\",\n", + " sync_batchnorm=True,\n", + " use_distributed_sampler=True, # or replace_sampler_ddp=True for PyTorch Lightning <2.0\n", + ")\n", + "trainer.fit(model=model, train_dataloaders=dataloader)" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": "-all" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/notebooks/pytorch_lightning_distributed/vicregl.ipynb b/examples/notebooks/pytorch_lightning_distributed/vicregl.ipynb new file mode 100644 index 000000000..a2890f835 --- /dev/null +++ b/examples/notebooks/pytorch_lightning_distributed/vicregl.ipynb @@ -0,0 +1,186 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "This example requires the following dependencies to be installed:\n", + "pip install lightly" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install lightly" + ] + }, + { + "cell_type": "markdown", + "id": "2", + "metadata": {}, + "source": [ + "Note: The model and training settings do not follow the reference settings\n", + "from the paper. The settings are chosen such that the example can easily be\n", + "run on a small dataset with a single GPU." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "import pytorch_lightning as pl\n", + "import torch\n", + "import torchvision\n", + "from torch import nn" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "from lightly.loss import VICRegLLoss" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "## The global projection head is the same as the Barlow Twins one\n", + "from lightly.models.modules import BarlowTwinsProjectionHead\n", + "from lightly.models.modules.heads import VicRegLLocalProjectionHead\n", + "from lightly.transforms.vicregl_transform import VICRegLTransform" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "class VICRegL(pl.LightningModule):\n", + " def __init__(self):\n", + " super().__init__()\n", + " resnet = torchvision.models.resnet18()\n", + " self.backbone = nn.Sequential(*list(resnet.children())[:-2])\n", + " self.projection_head = BarlowTwinsProjectionHead(512, 2048, 2048)\n", + " self.local_projection_head = VicRegLLocalProjectionHead(512, 128, 128)\n", + " self.average_pool = nn.AdaptiveAvgPool2d(output_size=(1, 1))\n", + " self.criterion = VICRegLLoss()\n", + "\n", + " def forward(self, x):\n", + " x = self.backbone(x)\n", + " y = self.average_pool(x).flatten(start_dim=1)\n", + " z = self.projection_head(y)\n", + " y_local = x.permute(0, 2, 3, 1) # (B, D, W, H) to (B, W, H, D)\n", + " z_local = self.local_projection_head(y_local)\n", + " return z, z_local\n", + "\n", + " def training_step(self, batch, batch_index):\n", + " views_and_grids = batch[0]\n", + " views = views_and_grids[: len(views_and_grids) // 2]\n", + " grids = views_and_grids[len(views_and_grids) // 2 :]\n", + " features = [self.forward(view) for view in views]\n", + " loss = self.criterion(\n", + " global_view_features=features[:2],\n", + " global_view_grids=grids[:2],\n", + " local_view_features=features[2:],\n", + " local_view_grids=grids[2:],\n", + " )\n", + " return loss\n", + "\n", + " def configure_optimizers(self):\n", + " optim = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)\n", + " return optim" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "model = VICRegL()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "transform = VICRegLTransform(n_local_views=0)\n", + "# we ignore object detection annotations by setting target_transform to return 0\n", + "dataset = torchvision.datasets.VOCDetection(\n", + " \"datasets/pascal_voc\",\n", + " download=True,\n", + " transform=transform,\n", + " target_transform=lambda t: 0,\n", + ")\n", + "# or create a dataset from a folder containing images or videos:\n", + "# dataset = LightlyDataset(\"path/to/folder\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "dataloader = torch.utils.data.DataLoader(\n", + " dataset,\n", + " batch_size=256,\n", + " shuffle=True,\n", + " drop_last=True,\n", + " num_workers=8,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [], + "source": [ + "# Train with DDP and use Synchronized Batch Norm for a more 
accurate batch norm\n", + "# calculation. Distributed sampling is also enabled with replace_sampler_ddp=True.\n", + "trainer = pl.Trainer(\n", + " max_epochs=10,\n", + " devices=\"auto\",\n", + " accelerator=\"gpu\",\n", + " strategy=\"ddp\",\n", + " sync_batchnorm=True,\n", + " use_distributed_sampler=True, # or replace_sampler_ddp=True for PyTorch Lightning <2.0\n", + ")\n", + "trainer.fit(model=model, train_dataloaders=dataloader)" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": "-all" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/pytorch/aim.py b/examples/pytorch/aim.py index 43d3108d0..9bcb319f3 100644 --- a/examples/pytorch/aim.py +++ b/examples/pytorch/aim.py @@ -1,3 +1,6 @@ +# This example requires the following dependencies to be installed: +# pip install "lightly[timm]" + # Note: The model and training settings do not follow the reference settings # from the paper. The settings are chosen such that the example can easily be # run on a small dataset with a single GPU. diff --git a/examples/pytorch/barlowtwins.py b/examples/pytorch/barlowtwins.py index c6631518d..03901fe8c 100644 --- a/examples/pytorch/barlowtwins.py +++ b/examples/pytorch/barlowtwins.py @@ -1,3 +1,6 @@ +# This example requires the following dependencies to be installed: +# pip install lightly + # Note: The model and training settings do not follow the reference settings # from the paper. The settings are chosen such that the example can easily be # run on a small dataset with a single GPU. diff --git a/examples/pytorch/byol.py b/examples/pytorch/byol.py index 4abcdf0a0..ff3afa252 100644 --- a/examples/pytorch/byol.py +++ b/examples/pytorch/byol.py @@ -1,3 +1,6 @@ +# This example requires the following dependencies to be installed: +# pip install lightly + # Note: The model and training settings do not follow the reference settings # from the paper. The settings are chosen such that the example can easily be # run on a small dataset with a single GPU. diff --git a/examples/pytorch/dcl.py b/examples/pytorch/dcl.py index 4acc1c6bb..ddb6b6501 100644 --- a/examples/pytorch/dcl.py +++ b/examples/pytorch/dcl.py @@ -1,3 +1,6 @@ +# This example requires the following dependencies to be installed: +# pip install lightly + # Note: The model and training settings do not follow the reference settings # from the paper. The settings are chosen such that the example can easily be # run on a small dataset with a single GPU. diff --git a/examples/pytorch/densecl.py b/examples/pytorch/densecl.py index fbbe59f37..e65e7ea2c 100644 --- a/examples/pytorch/densecl.py +++ b/examples/pytorch/densecl.py @@ -1,3 +1,6 @@ +# This example requires the following dependencies to be installed: +# pip install lightly + # Note: The model and training settings do not follow the reference settings # from the paper. The settings are chosen such that the example can easily be # run on a small dataset with a single GPU. diff --git a/examples/pytorch/dino.py b/examples/pytorch/dino.py index c5076cfa4..9c36f5e26 100644 --- a/examples/pytorch/dino.py +++ b/examples/pytorch/dino.py @@ -1,3 +1,6 @@ +# This example requires the following dependencies to be installed: +# pip install lightly + # Note: The model and training settings do not follow the reference settings # from the paper. The settings are chosen such that the example can easily be # run on a small dataset with a single GPU. 
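Note on the DDP trainer cells in the TiCo, VICReg and VICRegL notebooks above: they pass use_distributed_sampler=True and mention replace_sampler_ddp=True as the spelling for PyTorch Lightning versions before 2.0. The snippet below is a minimal sketch, not part of the generated notebooks, of how the correct keyword could be chosen at runtime; the version check via the packaging module is an assumption and not something the examples themselves do.

import pytorch_lightning as pl
from packaging import version

# Pick the sampler flag that matches the installed PyTorch Lightning version.
sampler_kwargs = (
    {"use_distributed_sampler": True}
    if version.parse(pl.__version__) >= version.parse("2.0")
    else {"replace_sampler_ddp": True}
)

trainer = pl.Trainer(
    max_epochs=10,
    devices="auto",
    accelerator="gpu",
    strategy="ddp",
    sync_batchnorm=True,
    **sampler_kwargs,
)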
diff --git a/examples/pytorch/fastsiam.py b/examples/pytorch/fastsiam.py index 12cbb46ed..18571d462 100644 --- a/examples/pytorch/fastsiam.py +++ b/examples/pytorch/fastsiam.py @@ -1,3 +1,6 @@ +# This example requires the following dependencies to be installed: +# pip install lightly + # Note: The model and training settings do not follow the reference settings # from the paper. The settings are chosen such that the example can easily be # run on a small dataset with a single GPU. diff --git a/examples/pytorch/ijepa.py b/examples/pytorch/ijepa.py index eb4730e04..d5511cf1b 100644 --- a/examples/pytorch/ijepa.py +++ b/examples/pytorch/ijepa.py @@ -1,3 +1,6 @@ +# This example requires the following dependencies to be installed: +# pip install lightly[timm] + import copy import torch diff --git a/examples/pytorch/mae.py b/examples/pytorch/mae.py index 2fffacfb9..8e84d7296 100644 --- a/examples/pytorch/mae.py +++ b/examples/pytorch/mae.py @@ -1,3 +1,6 @@ +# This example requires the following dependencies to be installed: +# pip install lightly[timm] + # Note: The model and training settings do not follow the reference settings # from the paper. The settings are chosen such that the example can easily be # run on a small dataset with a single GPU. diff --git a/examples/pytorch/mmcr.py b/examples/pytorch/mmcr.py index 194f95f22..7e8d66d55 100644 --- a/examples/pytorch/mmcr.py +++ b/examples/pytorch/mmcr.py @@ -1,3 +1,6 @@ +# This example requires the following dependencies to be installed: +# pip install lightly + # Note: The model and training settings do not follow the reference settings # from the paper. The settings are chosen such that the example can easily be # run on a small dataset with a single GPU. diff --git a/examples/pytorch/moco.py b/examples/pytorch/moco.py index 613dd4a20..eb9ec3c2c 100644 --- a/examples/pytorch/moco.py +++ b/examples/pytorch/moco.py @@ -1,3 +1,6 @@ +# This example requires the following dependencies to be installed: +# pip install lightly + # Note: The model and training settings do not follow the reference settings # from the paper. The settings are chosen such that the example can easily be # run on a small dataset with a single GPU. diff --git a/examples/pytorch/msn.py b/examples/pytorch/msn.py index f69eb7dae..42239bdc7 100644 --- a/examples/pytorch/msn.py +++ b/examples/pytorch/msn.py @@ -1,3 +1,6 @@ +# This example requires the following dependencies to be installed: +# pip install lightly + # Note: The model and training settings do not follow the reference settings # from the paper. The settings are chosen such that the example can easily be # run on a small dataset with a single GPU. diff --git a/examples/pytorch/nnclr.py b/examples/pytorch/nnclr.py index 530e18153..dd854dbda 100644 --- a/examples/pytorch/nnclr.py +++ b/examples/pytorch/nnclr.py @@ -1,3 +1,6 @@ +# This example requires the following dependencies to be installed: +# pip install lightly + # Note: The model and training settings do not follow the reference settings # from the paper. The settings are chosen such that the example can easily be # run on a small dataset with a single GPU. diff --git a/examples/pytorch/pmsn.py b/examples/pytorch/pmsn.py index 545e38fc4..3fde2fd47 100644 --- a/examples/pytorch/pmsn.py +++ b/examples/pytorch/pmsn.py @@ -1,3 +1,6 @@ +# This example requires the following dependencies to be installed: +# pip install lightly + # Note: The model and training settings do not follow the reference settings # from the paper. 
The settings are chosen such that the example can easily be # run on a small dataset with a single GPU. diff --git a/examples/pytorch/simclr.py b/examples/pytorch/simclr.py index 2655e8384..04e878f0a 100644 --- a/examples/pytorch/simclr.py +++ b/examples/pytorch/simclr.py @@ -1,3 +1,6 @@ +# This example requires the following dependencies to be installed: +# pip install lightly + # Note: The model and training settings do not follow the reference settings # from the paper. The settings are chosen such that the example can easily be # run on a small dataset with a single GPU. diff --git a/examples/pytorch/simmim.py b/examples/pytorch/simmim.py index 74d8b7f97..d61223d55 100644 --- a/examples/pytorch/simmim.py +++ b/examples/pytorch/simmim.py @@ -1,3 +1,6 @@ +# This example requires the following dependencies to be installed: +# pip install lightly + import torch import torchvision from torch import nn diff --git a/examples/pytorch/simsiam.py b/examples/pytorch/simsiam.py index 18d196f84..ba255ba55 100644 --- a/examples/pytorch/simsiam.py +++ b/examples/pytorch/simsiam.py @@ -1,3 +1,6 @@ +# This example requires the following dependencies to be installed: +# pip install lightly + # Note: The model and training settings do not follow the reference settings # from the paper. The settings are chosen such that the example can easily be # run on a small dataset with a single GPU. diff --git a/examples/pytorch/smog.py b/examples/pytorch/smog.py index 0e59758f7..c96da85cf 100644 --- a/examples/pytorch/smog.py +++ b/examples/pytorch/smog.py @@ -1,3 +1,6 @@ +# This example requires the following dependencies to be installed: +# pip install lightly + # Note: The model and training settings do not follow the reference settings # from the paper. The settings are chosen such that the example can easily be # run on a small dataset with a single GPU. diff --git a/examples/pytorch/swav.py b/examples/pytorch/swav.py index ed05153b9..e0a079acf 100644 --- a/examples/pytorch/swav.py +++ b/examples/pytorch/swav.py @@ -1,3 +1,6 @@ +# This example requires the following dependencies to be installed: +# pip install lightly + # Note: The model and training settings do not follow the reference settings # from the paper. The settings are chosen such that the example can easily be # run on a small dataset with a single GPU. diff --git a/examples/pytorch/swav_queue.py b/examples/pytorch/swav_queue.py index a3a148482..eed5bdf8f 100644 --- a/examples/pytorch/swav_queue.py +++ b/examples/pytorch/swav_queue.py @@ -1,3 +1,6 @@ +# This example requires the following dependencies to be installed: +# pip install lightly + # Note: The model and training settings do not follow the reference settings # from the paper. The settings are chosen such that the example can easily be # run on a small dataset with a single GPU. diff --git a/examples/pytorch/tico.py b/examples/pytorch/tico.py index d2141ec62..841129753 100644 --- a/examples/pytorch/tico.py +++ b/examples/pytorch/tico.py @@ -1,3 +1,6 @@ +# This example requires the following dependencies to be installed: +# pip install lightly + # Note: The model and training settings do not follow the reference settings # from the paper. The settings are chosen such that the example can easily be # run on a small dataset with a single GPU. 
diff --git a/examples/pytorch/vicreg.py b/examples/pytorch/vicreg.py index 39968cb08..49a9683c1 100644 --- a/examples/pytorch/vicreg.py +++ b/examples/pytorch/vicreg.py @@ -1,3 +1,6 @@ +# This example requires the following dependencies to be installed: +# pip install lightly + import torch import torchvision from torch import nn diff --git a/examples/pytorch/vicregl.py b/examples/pytorch/vicregl.py index 082dd1b9d..d9380c0ea 100644 --- a/examples/pytorch/vicregl.py +++ b/examples/pytorch/vicregl.py @@ -1,3 +1,6 @@ +# This example requires the following dependencies to be installed: +# pip install lightly + import torch import torchvision from torch import nn diff --git a/examples/pytorch_lightning/aim.py b/examples/pytorch_lightning/aim.py index c871357eb..2c36221be 100644 --- a/examples/pytorch_lightning/aim.py +++ b/examples/pytorch_lightning/aim.py @@ -1,3 +1,6 @@ +# This example requires the following dependencies to be installed: +# pip install "lightly[timm]" + # Note: The model and training settings do not follow the reference settings # from the paper. The settings are chosen such that the example can easily be # run on a small dataset with a single GPU. diff --git a/examples/pytorch_lightning/barlowtwins.py b/examples/pytorch_lightning/barlowtwins.py index 86d069e5a..a5ec24418 100644 --- a/examples/pytorch_lightning/barlowtwins.py +++ b/examples/pytorch_lightning/barlowtwins.py @@ -1,3 +1,6 @@ +# This example requires the following dependencies to be installed: +# pip install lightly + # Note: The model and training settings do not follow the reference settings # from the paper. The settings are chosen such that the example can easily be # run on a small dataset with a single GPU. diff --git a/examples/pytorch_lightning/byol.py b/examples/pytorch_lightning/byol.py index a87feb4df..1ac965b05 100644 --- a/examples/pytorch_lightning/byol.py +++ b/examples/pytorch_lightning/byol.py @@ -1,3 +1,6 @@ +# This example requires the following dependencies to be installed: +# pip install lightly + # Note: The model and training settings do not follow the reference settings # from the paper. The settings are chosen such that the example can easily be # run on a small dataset with a single GPU. diff --git a/examples/pytorch_lightning/dcl.py b/examples/pytorch_lightning/dcl.py index ae7e37d13..6afc3a564 100644 --- a/examples/pytorch_lightning/dcl.py +++ b/examples/pytorch_lightning/dcl.py @@ -1,3 +1,6 @@ +# This example requires the following dependencies to be installed: +# pip install lightly + # Note: The model and training settings do not follow the reference settings # from the paper. The settings are chosen such that the example can easily be # run on a small dataset with a single GPU. diff --git a/examples/pytorch_lightning/densecl.py b/examples/pytorch_lightning/densecl.py index dd4a5a6f8..7030e2e27 100644 --- a/examples/pytorch_lightning/densecl.py +++ b/examples/pytorch_lightning/densecl.py @@ -1,3 +1,6 @@ +# This example requires the following dependencies to be installed: +# pip install lightly + # Note: The model and training settings do not follow the reference settings # from the paper. The settings are chosen such that the example can easily be # run on a small dataset with a single GPU. 
diff --git a/examples/pytorch_lightning/dino.py b/examples/pytorch_lightning/dino.py index 2db3b3cde..bc9d86db2 100644 --- a/examples/pytorch_lightning/dino.py +++ b/examples/pytorch_lightning/dino.py @@ -1,3 +1,6 @@ +# This example requires the following dependencies to be installed: +# pip install lightly + # Note: The model and training settings do not follow the reference settings # from the paper. The settings are chosen such that the example can easily be # run on a small dataset with a single GPU. diff --git a/examples/pytorch_lightning/fastsiam.py b/examples/pytorch_lightning/fastsiam.py index 9c6d0c29e..6a0f06cdc 100644 --- a/examples/pytorch_lightning/fastsiam.py +++ b/examples/pytorch_lightning/fastsiam.py @@ -1,3 +1,6 @@ +# This example requires the following dependencies to be installed: +# pip install lightly + # Note: The model and training settings do not follow the reference settings # from the paper. The settings are chosen such that the example can easily be # run on a small dataset with a single GPU. diff --git a/examples/pytorch_lightning/ijepa.py b/examples/pytorch_lightning/ijepa.py index 464090415..45bab6f1a 100644 --- a/examples/pytorch_lightning/ijepa.py +++ b/examples/pytorch_lightning/ijepa.py @@ -1 +1 @@ -# TODO +# TODO, for now please refer to our pure pytorch example diff --git a/examples/pytorch_lightning/mae.py b/examples/pytorch_lightning/mae.py index 3c29ae487..97106a910 100644 --- a/examples/pytorch_lightning/mae.py +++ b/examples/pytorch_lightning/mae.py @@ -1,3 +1,6 @@ +# This example requires the following dependencies to be installed: +# pip install "lightly[timm]" + # Note: The model and training settings do not follow the reference settings # from the paper. The settings are chosen such that the example can easily be # run on a small dataset with a single GPU. diff --git a/examples/pytorch_lightning/mmcr.py b/examples/pytorch_lightning/mmcr.py index 0466d4b69..f5bd0841c 100644 --- a/examples/pytorch_lightning/mmcr.py +++ b/examples/pytorch_lightning/mmcr.py @@ -1,3 +1,6 @@ +# This example requires the following dependencies to be installed: +# pip install lightly + # Note: The model and training settings do not follow the reference settings # from the paper. The settings are chosen such that the example can easily be # run on a small dataset with a single GPU. diff --git a/examples/pytorch_lightning/moco.py b/examples/pytorch_lightning/moco.py index 7aa32dbbf..30ef6996c 100644 --- a/examples/pytorch_lightning/moco.py +++ b/examples/pytorch_lightning/moco.py @@ -1,3 +1,6 @@ +# This example requires the following dependencies to be installed: +# pip install lightly + # Note: The model and training settings do not follow the reference settings # from the paper. The settings are chosen such that the example can easily be # run on a small dataset with a single GPU. diff --git a/examples/pytorch_lightning/msn.py b/examples/pytorch_lightning/msn.py index 91569b215..85ef7b26e 100644 --- a/examples/pytorch_lightning/msn.py +++ b/examples/pytorch_lightning/msn.py @@ -1,3 +1,6 @@ +# This example requires the following dependencies to be installed: +# pip install lightly + # Note: The model and training settings do not follow the reference settings # from the paper. The settings are chosen such that the example can easily be # run on a small dataset with a single GPU. 
diff --git a/examples/pytorch_lightning/nnclr.py b/examples/pytorch_lightning/nnclr.py index 0a3c462cc..7debe0ca5 100644 --- a/examples/pytorch_lightning/nnclr.py +++ b/examples/pytorch_lightning/nnclr.py @@ -1,3 +1,6 @@ +# This example requires the following dependencies to be installed: +# pip install lightly + # Note: The model and training settings do not follow the reference settings # from the paper. The settings are chosen such that the example can easily be # run on a small dataset with a single GPU. diff --git a/examples/pytorch_lightning/pmsn.py b/examples/pytorch_lightning/pmsn.py index 7506c3c1d..8faf513cd 100644 --- a/examples/pytorch_lightning/pmsn.py +++ b/examples/pytorch_lightning/pmsn.py @@ -1,3 +1,6 @@ +# This example requires the following dependencies to be installed: +# pip install lightly + # Note: The model and training settings do not follow the reference settings # from the paper. The settings are chosen such that the example can easily be # run on a small dataset with a single GPU. diff --git a/examples/pytorch_lightning/simclr.py b/examples/pytorch_lightning/simclr.py index 0c45e988c..753f7bef9 100644 --- a/examples/pytorch_lightning/simclr.py +++ b/examples/pytorch_lightning/simclr.py @@ -1,3 +1,6 @@ +# This example requires the following dependencies to be installed: +# pip install lightly + # Note: The model and training settings do not follow the reference settings # from the paper. The settings are chosen such that the example can easily be # run on a small dataset with a single GPU. diff --git a/examples/pytorch_lightning/simmim.py b/examples/pytorch_lightning/simmim.py index 00b98b74f..363daee4e 100644 --- a/examples/pytorch_lightning/simmim.py +++ b/examples/pytorch_lightning/simmim.py @@ -1,3 +1,6 @@ +# This example requires the following dependencies to be installed: +# pip install lightly + import pytorch_lightning as pl import torch import torchvision diff --git a/examples/pytorch_lightning/simsiam.py b/examples/pytorch_lightning/simsiam.py index 594751430..0272259b8 100644 --- a/examples/pytorch_lightning/simsiam.py +++ b/examples/pytorch_lightning/simsiam.py @@ -1,3 +1,6 @@ +# This example requires the following dependencies to be installed: +# pip install lightly + # Note: The model and training settings do not follow the reference settings # from the paper. The settings are chosen such that the example can easily be # run on a small dataset with a single GPU. diff --git a/examples/pytorch_lightning/smog.py b/examples/pytorch_lightning/smog.py index 245c5e8be..224e3ffed 100644 --- a/examples/pytorch_lightning/smog.py +++ b/examples/pytorch_lightning/smog.py @@ -1,3 +1,6 @@ +# This example requires the following dependencies to be installed: +# pip install lightly + # Note: The model and training settings do not follow the reference settings # from the paper. The settings are chosen such that the example can easily be # run on a small dataset with a single GPU. diff --git a/examples/pytorch_lightning/swav.py b/examples/pytorch_lightning/swav.py index 4b05aef86..404a99ca9 100644 --- a/examples/pytorch_lightning/swav.py +++ b/examples/pytorch_lightning/swav.py @@ -1,3 +1,6 @@ +# This example requires the following dependencies to be installed: +# pip install lightly + # Note: The model and training settings do not follow the reference settings # from the paper. The settings are chosen such that the example can easily be # run on a small dataset with a single GPU. 
diff --git a/examples/pytorch_lightning/swav_queue.py b/examples/pytorch_lightning/swav_queue.py index 2aad1c44a..6f3bb2f97 100644 --- a/examples/pytorch_lightning/swav_queue.py +++ b/examples/pytorch_lightning/swav_queue.py @@ -1,3 +1,6 @@ +# This example requires the following dependencies to be installed: +# pip install lightly + # Note: The model and training settings do not follow the reference settings # from the paper. The settings are chosen such that the example can easily be # run on a small dataset with a single GPU. diff --git a/examples/pytorch_lightning/tico.py b/examples/pytorch_lightning/tico.py index 48f9d0db1..1032ffded 100644 --- a/examples/pytorch_lightning/tico.py +++ b/examples/pytorch_lightning/tico.py @@ -1,3 +1,6 @@ +# This example requires the following dependencies to be installed: +# pip install lightly + import copy import pytorch_lightning as pl diff --git a/examples/pytorch_lightning/vicreg.py b/examples/pytorch_lightning/vicreg.py index 2770309e7..42b0b2bb9 100644 --- a/examples/pytorch_lightning/vicreg.py +++ b/examples/pytorch_lightning/vicreg.py @@ -1,3 +1,6 @@ +# This example requires the following dependencies to be installed: +# pip install lightly + # Note: The model and training settings do not follow the reference settings # from the paper. The settings are chosen such that the example can easily be # run on a small dataset with a single GPU. diff --git a/examples/pytorch_lightning/vicregl.py b/examples/pytorch_lightning/vicregl.py index 4cc6c7795..6deafdcdc 100644 --- a/examples/pytorch_lightning/vicregl.py +++ b/examples/pytorch_lightning/vicregl.py @@ -1,3 +1,6 @@ +# This example requires the following dependencies to be installed: +# pip install lightly + # Note: The model and training settings do not follow the reference settings # from the paper. The settings are chosen such that the example can easily be # run on a small dataset with a single GPU. diff --git a/examples/pytorch_lightning_distributed/aim.py b/examples/pytorch_lightning_distributed/aim.py index ef1dd74e8..aa56553b2 100644 --- a/examples/pytorch_lightning_distributed/aim.py +++ b/examples/pytorch_lightning_distributed/aim.py @@ -1,3 +1,6 @@ +# This example requires the following dependencies to be installed: +# pip install "lightly[timm]" + # Note: The model and training settings do not follow the reference settings # from the paper. The settings are chosen such that the example can easily be # run on a small dataset with a single GPU. diff --git a/examples/pytorch_lightning_distributed/barlowtwins.py b/examples/pytorch_lightning_distributed/barlowtwins.py index 876d7e5f2..4e76b73c6 100644 --- a/examples/pytorch_lightning_distributed/barlowtwins.py +++ b/examples/pytorch_lightning_distributed/barlowtwins.py @@ -1,3 +1,6 @@ +# This example requires the following dependencies to be installed: +# pip install lightly + # Note: The model and training settings do not follow the reference settings # from the paper. The settings are chosen such that the example can easily be # run on a small dataset with a single GPU. diff --git a/examples/pytorch_lightning_distributed/byol.py b/examples/pytorch_lightning_distributed/byol.py index 6fb9d88fe..1ffc0cb83 100644 --- a/examples/pytorch_lightning_distributed/byol.py +++ b/examples/pytorch_lightning_distributed/byol.py @@ -1,3 +1,6 @@ +# This example requires the following dependencies to be installed: +# pip install lightly + # Note: The model and training settings do not follow the reference settings # from the paper. 
The settings are chosen such that the example can easily be # run on a small dataset with a single GPU. diff --git a/examples/pytorch_lightning_distributed/dcl.py b/examples/pytorch_lightning_distributed/dcl.py index ecf6e7bcd..678f721e6 100644 --- a/examples/pytorch_lightning_distributed/dcl.py +++ b/examples/pytorch_lightning_distributed/dcl.py @@ -1,3 +1,6 @@ +# This example requires the following dependencies to be installed: +# pip install lightly + # Note: The model and training settings do not follow the reference settings # from the paper. The settings are chosen such that the example can easily be # run on a small dataset with a single GPU. diff --git a/examples/pytorch_lightning_distributed/densecl.py b/examples/pytorch_lightning_distributed/densecl.py index 5e5e40dbe..e381461d9 100644 --- a/examples/pytorch_lightning_distributed/densecl.py +++ b/examples/pytorch_lightning_distributed/densecl.py @@ -1,3 +1,6 @@ +# This example requires the following dependencies to be installed: +# pip install lightly + # Note: The model and training settings do not follow the reference settings # from the paper. The settings are chosen such that the example can easily be # run on a small dataset with a single GPU. diff --git a/examples/pytorch_lightning_distributed/dino.py b/examples/pytorch_lightning_distributed/dino.py index b81cdf1a8..df993d056 100644 --- a/examples/pytorch_lightning_distributed/dino.py +++ b/examples/pytorch_lightning_distributed/dino.py @@ -1,3 +1,6 @@ +# This example requires the following dependencies to be installed: +# pip install lightly + # Note: The model and training settings do not follow the reference settings # from the paper. The settings are chosen such that the example can easily be # run on a small dataset with a single GPU. diff --git a/examples/pytorch_lightning_distributed/fastsiam.py b/examples/pytorch_lightning_distributed/fastsiam.py index 0556d98d7..55c66698f 100644 --- a/examples/pytorch_lightning_distributed/fastsiam.py +++ b/examples/pytorch_lightning_distributed/fastsiam.py @@ -1,3 +1,6 @@ +# This example requires the following dependencies to be installed: +# pip install lightly + # Note: The model and training settings do not follow the reference settings # from the paper. The settings are chosen such that the example can easily be # run on a small dataset with a single GPU. diff --git a/examples/pytorch_lightning_distributed/ijepa.py b/examples/pytorch_lightning_distributed/ijepa.py index 464090415..45bab6f1a 100644 --- a/examples/pytorch_lightning_distributed/ijepa.py +++ b/examples/pytorch_lightning_distributed/ijepa.py @@ -1 +1 @@ -# TODO +# TODO, for now please refer to our pure pytorch example diff --git a/examples/pytorch_lightning_distributed/mae.py b/examples/pytorch_lightning_distributed/mae.py index d4414d210..6dbe60783 100644 --- a/examples/pytorch_lightning_distributed/mae.py +++ b/examples/pytorch_lightning_distributed/mae.py @@ -1,3 +1,6 @@ +# This example requires the following dependencies to be installed: +# pip install "lightly[timm]" + # Note: The model and training settings do not follow the reference settings # from the paper. The settings are chosen such that the example can easily be # run on a small dataset with a single GPU. 
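The AIM, MAE and I-JEPA examples above are the ones whose new header asks for pip install "lightly[timm]" rather than plain lightly. A hypothetical guard such as the following, which is not part of this diff, could fail early with a clear message when the optional timm extra is missing:

import importlib.util

# AIM, MAE and I-JEPA rely on timm models; check for the optional dependency up front.
if importlib.util.find_spec("timm") is None:
    raise RuntimeError(
        'Missing optional dependency "timm". Install it with: pip install "lightly[timm]"'
    )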
diff --git a/examples/pytorch_lightning_distributed/mmcr.py b/examples/pytorch_lightning_distributed/mmcr.py index 73f0a1530..348ece7c7 100644 --- a/examples/pytorch_lightning_distributed/mmcr.py +++ b/examples/pytorch_lightning_distributed/mmcr.py @@ -1,3 +1,6 @@ +# This example requires the following dependencies to be installed: +# pip install lightly + # Note: The model and training settings do not follow the reference settings # from the paper. The settings are chosen such that the example can easily be # run on a small dataset with a single GPU. diff --git a/examples/pytorch_lightning_distributed/moco.py b/examples/pytorch_lightning_distributed/moco.py index 6a59a096a..57453cfeb 100644 --- a/examples/pytorch_lightning_distributed/moco.py +++ b/examples/pytorch_lightning_distributed/moco.py @@ -1,3 +1,6 @@ +# This example requires the following dependencies to be installed: +# pip install lightly + # Note: The model and training settings do not follow the reference settings # from the paper. The settings are chosen such that the example can easily be # run on a small dataset with a single GPU. diff --git a/examples/pytorch_lightning_distributed/msn.py b/examples/pytorch_lightning_distributed/msn.py index 8b5c565ae..edcf4fdbd 100644 --- a/examples/pytorch_lightning_distributed/msn.py +++ b/examples/pytorch_lightning_distributed/msn.py @@ -1,3 +1,6 @@ +# This example requires the following dependencies to be installed: +# pip install lightly + # Note: The model and training settings do not follow the reference settings # from the paper. The settings are chosen such that the example can easily be # run on a small dataset with a single GPU. diff --git a/examples/pytorch_lightning_distributed/nnclr.py b/examples/pytorch_lightning_distributed/nnclr.py index d925265bd..a1771da3d 100644 --- a/examples/pytorch_lightning_distributed/nnclr.py +++ b/examples/pytorch_lightning_distributed/nnclr.py @@ -1,3 +1,6 @@ +# This example requires the following dependencies to be installed: +# pip install lightly + # Note: The model and training settings do not follow the reference settings # from the paper. The settings are chosen such that the example can easily be # run on a small dataset with a single GPU. diff --git a/examples/pytorch_lightning_distributed/pmsn.py b/examples/pytorch_lightning_distributed/pmsn.py index 7af776096..73dcf8c0a 100644 --- a/examples/pytorch_lightning_distributed/pmsn.py +++ b/examples/pytorch_lightning_distributed/pmsn.py @@ -1,3 +1,6 @@ +# This example requires the following dependencies to be installed: +# pip install lightly + # Note: The model and training settings do not follow the reference settings # from the paper. The settings are chosen such that the example can easily be # run on a small dataset with a single GPU. 
diff --git a/examples/pytorch_lightning_distributed/simclr.py b/examples/pytorch_lightning_distributed/simclr.py index 457795a5e..fd71c8a27 100644 --- a/examples/pytorch_lightning_distributed/simclr.py +++ b/examples/pytorch_lightning_distributed/simclr.py @@ -1,3 +1,6 @@ +# This example requires the following dependencies to be installed: +# pip install lightly + import pytorch_lightning as pl import torch import torchvision diff --git a/examples/pytorch_lightning_distributed/simmim.py b/examples/pytorch_lightning_distributed/simmim.py index 3cfb6f028..492bdfd6e 100644 --- a/examples/pytorch_lightning_distributed/simmim.py +++ b/examples/pytorch_lightning_distributed/simmim.py @@ -1,3 +1,6 @@ +# This example requires the following dependencies to be installed: +# pip install lightly + import pytorch_lightning as pl import torch import torchvision diff --git a/examples/pytorch_lightning_distributed/simsiam.py b/examples/pytorch_lightning_distributed/simsiam.py index 049784f9e..21e559a5d 100644 --- a/examples/pytorch_lightning_distributed/simsiam.py +++ b/examples/pytorch_lightning_distributed/simsiam.py @@ -1,3 +1,6 @@ +# This example requires the following dependencies to be installed: +# pip install lightly + # Note: The model and training settings do not follow the reference settings # from the paper. The settings are chosen such that the example can easily be # run on a small dataset with a single GPU. diff --git a/examples/pytorch_lightning_distributed/swav.py b/examples/pytorch_lightning_distributed/swav.py index 6fcaa8a9b..e91b83e42 100644 --- a/examples/pytorch_lightning_distributed/swav.py +++ b/examples/pytorch_lightning_distributed/swav.py @@ -1,3 +1,6 @@ +# This example requires the following dependencies to be installed: +# pip install lightly + # Note: The model and training settings do not follow the reference settings # from the paper. The settings are chosen such that the example can easily be # run on a small dataset with a single GPU. diff --git a/examples/pytorch_lightning_distributed/tico.py b/examples/pytorch_lightning_distributed/tico.py index 5ed02bad2..ce4ecf9c2 100644 --- a/examples/pytorch_lightning_distributed/tico.py +++ b/examples/pytorch_lightning_distributed/tico.py @@ -1,3 +1,6 @@ +# This example requires the following dependencies to be installed: +# pip install lightly + import copy import pytorch_lightning as pl diff --git a/examples/pytorch_lightning_distributed/vicreg.py b/examples/pytorch_lightning_distributed/vicreg.py index bc8aa38ad..4e15107a8 100644 --- a/examples/pytorch_lightning_distributed/vicreg.py +++ b/examples/pytorch_lightning_distributed/vicreg.py @@ -1,3 +1,6 @@ +# This example requires the following dependencies to be installed: +# pip install lightly + # Note: The model and training settings do not follow the reference settings # from the paper. The settings are chosen such that the example can easily be # run on a small dataset with a single GPU. diff --git a/examples/pytorch_lightning_distributed/vicregl.py b/examples/pytorch_lightning_distributed/vicregl.py index 67c418463..7ad8bf241 100644 --- a/examples/pytorch_lightning_distributed/vicregl.py +++ b/examples/pytorch_lightning_distributed/vicregl.py @@ -1,3 +1,6 @@ +# This example requires the following dependencies to be installed: +# pip install lightly + # Note: The model and training settings do not follow the reference settings # from the paper. The settings are chosen such that the example can easily be # run on a small dataset with a single GPU. 
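For reference on the batch layout the VICRegL notebook relies on: its training_step splits batch[0] in half into views and grids, which works because VICRegLTransform returns all views first, followed by one location grid per view. The sketch below illustrates this under that assumption, using a dummy PIL image; it is not taken from the examples.

from PIL import Image
from lightly.transforms.vicregl_transform import VICRegLTransform

transform = VICRegLTransform(n_local_views=0)
dummy_image = Image.new("RGB", (224, 224))

# With n_local_views=0 the output is [global_view_0, global_view_1, grid_0, grid_1].
views_and_grids = transform(dummy_image)
views = views_and_grids[: len(views_and_grids) // 2]
grids = views_and_grids[len(views_and_grids) // 2 :]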
diff --git a/pyproject.toml b/pyproject.toml index 4f77393ae..5b083643d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -83,7 +83,9 @@ dev = [ "isort==5.11.5", # frozen version to avoid differences between CI and local dev machines "mypy==1.4.1", # frozen version to avoid differences between CI and local dev machines "types-python-dateutil", - "types-toml" + "types-toml", + "nbformat", + "jupytext" ] # Minimal dependencies against which we test. Older versions might work depending on the # functionality used. @@ -350,3 +352,7 @@ module = [ "lightly.openapi_generated.*", ] ignore_errors = true + +[tool.jupytext] +notebook_metadata_filter="-all" +cell_metadata_filter="-all"
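The [tool.jupytext] section added above filters out notebook and cell metadata, and nbformat and jupytext are added to the dev dependencies to support the notebook generation. The snippet below is a minimal sketch of that conversion using the jupytext Python API; the file paths are placeholders and the repository's actual conversion may differ in detail.

import jupytext

# Read an example script and turn it into a notebook object.
notebook = jupytext.read("examples/pytorch/simclr.py")

# Drop notebook- and cell-level metadata, mirroring the "-all" filters configured above.
notebook.metadata = {}
for cell in notebook.cells:
    cell.metadata = {}

# Write the resulting .ipynb alongside the other generated notebooks.
jupytext.write(notebook, "examples/notebooks/pytorch/simclr.ipynb")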