diff --git a/code/dnn_marathon.ipynb b/code/dnn_marathon.ipynb
index d2af45f..bd2ec4d 100644
--- a/code/dnn_marathon.ipynb
+++ b/code/dnn_marathon.ipynb
@@ -921,21 +921,259 @@
    "source": [
     "architecture of transformer\n",
     "\n",
-    "reference: https://medium.com/mlearning-ai/vision-transformers-from-scratch-pytorch-a-step-by-step-guide-96c3313c2e0c"
+    "reference: https://medium.com/mlearning-ai/vision-transformers-from-scratch-pytorch-a-step-by-step-guide-96c3313c2e0c\n",
+    "\n",
+    "![image.png](figures/vit.png)"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 5,
    "metadata": {},
    "outputs": [],
-   "source": []
+   "source": [
+    "import numpy as np\n",
+    "\n",
+    "from tqdm import tqdm, trange\n",
+    "\n",
+    "import torch\n",
+    "import torch.nn as nn\n",
+    "from torch.optim import Adam\n",
+    "from torch.nn import CrossEntropyLoss\n",
+    "from torch.utils.data import DataLoader\n",
+    "\n",
+    "from torchvision.transforms import ToTensor\n",
+    "from torchvision.datasets.mnist import MNIST\n",
+    "\n",
+    "np.random.seed(0)\n",
+    "torch.manual_seed(0)\n",
+    "\n",
+    "def patchify(images, n_patches):\n",
+    "    n, c, h, w = images.shape\n",
+    "\n",
+    "    assert h == w, \"Patchify method is implemented for square images only\"\n",
+    "\n",
+    "    patches = torch.zeros(n, n_patches ** 2, h * w * c // n_patches ** 2)\n",
+    "    patch_size = h // n_patches\n",
+    "\n",
+    "    for idx, image in enumerate(images):\n",
+    "        for i in range(n_patches):\n",
+    "            for j in range(n_patches):\n",
+    "                patch = image[:, i * patch_size: (i + 1) * patch_size, j * patch_size: (j + 1) * patch_size]\n",
+    "                patches[idx, i * n_patches + j] = patch.flatten()\n",
+    "    return patches\n",
+    "\n",
+    "def get_positional_embeddings(sequence_length, d):\n",
+    "    result = torch.ones(sequence_length, d)\n",
+    "    for i in range(sequence_length):\n",
+    "        for j in range(d):\n",
+    "            result[i][j] = np.sin(i / (10000 ** (j / d))) if j % 2 == 0 else np.cos(i / (10000 ** ((j - 1) / d)))\n",
+    "    return result\n",
+    "\n",
+    "class MyMSA(nn.Module):\n",
+    "    def __init__(self, d, n_heads=2):\n",
+    "        super(MyMSA, self).__init__()\n",
+    "        self.d = d\n",
+    "        self.n_heads = n_heads\n",
+    "\n",
+    "        assert d % n_heads == 0, f\"Can't divide dimension {d} into {n_heads} heads\"\n",
+    "\n",
+    "        d_head = int(d / n_heads)\n",
+    "        self.q_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for _ in range(self.n_heads)])\n",
+    "        self.k_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for _ in range(self.n_heads)])\n",
+    "        self.v_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for _ in range(self.n_heads)])\n",
+    "        self.d_head = d_head\n",
+    "        self.softmax = nn.Softmax(dim=-1)\n",
+    "\n",
+    "    def forward(self, sequences):\n",
+    "        # Sequences has shape (N, seq_length, token_dim)\n",
+    "        # We go into shape (N, seq_length, n_heads, token_dim / n_heads)\n",
+    "        # And come back to (N, seq_length, item_dim) (through concatenation)\n",
+    "        result = []\n",
+    "        for sequence in sequences:\n",
+    "            seq_result = []\n",
+    "            for head in range(self.n_heads):\n",
+    "                q_mapping = self.q_mappings[head]\n",
+    "                k_mapping = self.k_mappings[head]\n",
+    "                v_mapping = self.v_mappings[head]\n",
+    "\n",
+    "                seq = sequence[:, head * self.d_head: (head + 1) * self.d_head]\n",
+    "                q, k, v = q_mapping(seq), k_mapping(seq), v_mapping(seq)\n",
+    "\n",
+    "                attention = self.softmax(q @ k.T / (self.d_head ** 0.5))\n",
+    "                seq_result.append(attention @ v)\n",
+    "            result.append(torch.hstack(seq_result))\n",
+    "        return torch.cat([torch.unsqueeze(r, dim=0) for r in result])\n",
+    "\n",
"\n", + "class MyViTBlock(nn.Module):\n", + " def __init__(self, hidden_d, n_heads, mlp_ratio=4):\n", + " super(MyViTBlock, self).__init__()\n", + " self.hidden_d = hidden_d\n", + " self.n_heads = n_heads\n", + "\n", + " self.norm1 = nn.LayerNorm(hidden_d)\n", + " self.mhsa = MyMSA(hidden_d, n_heads)\n", + " self.norm2 = nn.LayerNorm(hidden_d)\n", + " self.mlp = nn.Sequential(\n", + " nn.Linear(hidden_d, mlp_ratio * hidden_d),\n", + " nn.GELU(),\n", + " nn.Linear(mlp_ratio * hidden_d, hidden_d)\n", + " )\n", + "\n", + " def forward(self, x):\n", + " out = x + self.mhsa(self.norm1(x))\n", + " out = out + self.mlp(self.norm2(out))\n", + " return out\n", + "\n", + "class MyViT(nn.Module):\n", + " def __init__(self, chw, n_patches=7, n_blocks=2, hidden_d=8, n_heads=2, out_d=10):\n", + " # Super constructor\n", + " super(MyViT, self).__init__()\n", + " \n", + " # Attributes\n", + " self.chw = chw # ( C , H , W )\n", + " self.n_patches = n_patches\n", + " self.n_blocks = n_blocks\n", + " self.n_heads = n_heads\n", + " self.hidden_d = hidden_d\n", + " \n", + " # Input and patches sizes\n", + " assert chw[1] % n_patches == 0, \"Input shape not entirely divisible by number of patches\"\n", + " assert chw[2] % n_patches == 0, \"Input shape not entirely divisible by number of patches\"\n", + " self.patch_size = (chw[1] / n_patches, chw[2] / n_patches)\n", + "\n", + " # 1) Linear mapper\n", + " self.input_d = int(chw[0] * self.patch_size[0] * self.patch_size[1])\n", + " self.linear_mapper = nn.Linear(self.input_d, self.hidden_d)\n", + " \n", + " # 2) Learnable classification token\n", + " self.class_token = nn.Parameter(torch.rand(1, self.hidden_d))\n", + " \n", + " # 3) Positional embedding\n", + " self.register_buffer('positional_embeddings', get_positional_embeddings(n_patches ** 2 + 1, hidden_d), persistent=False)\n", + " \n", + " # 4) Transformer encoder blocks\n", + " self.blocks = nn.ModuleList([MyViTBlock(hidden_d, n_heads) for _ in range(n_blocks)])\n", + " \n", + " # 5) Classification MLPk\n", + " self.mlp = nn.Sequential(\n", + " nn.Linear(self.hidden_d, out_d),\n", + " nn.Softmax(dim=-1)\n", + " )\n", + "\n", + " def forward(self, images):\n", + " # Dividing images into patches\n", + " n, c, h, w = images.shape\n", + " patches = patchify(images, self.n_patches).to(self.positional_embeddings.device)\n", + " \n", + " # Running linear layer tokenization\n", + " # Map the vector corresponding to each patch to the hidden size dimension\n", + " tokens = self.linear_mapper(patches)\n", + " \n", + " # Adding classification token to the tokens\n", + " tokens = torch.cat((self.class_token.expand(n, 1, -1), tokens), dim=1)\n", + " \n", + " # Adding positional embedding\n", + " out = tokens + self.positional_embeddings.repeat(n, 1, 1)\n", + " \n", + " # Transformer Blocks\n", + " for block in self.blocks:\n", + " out = block(out)\n", + " \n", + " # Getting the classification token only\n", + " out = out[:, 0]\n", + " \n", + " return self.mlp(out) # Map to output dimension, output category distribution" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "----------------------------------------------------------------\n", + " Layer (type) Output Shape Param #\n", + "================================================================\n", + " Linear-1 [-1, 49, 8] 24,584\n", + " LayerNorm-2 [-1, 50, 8] 16\n", + " Linear-3 [-1, 4] 20\n", + " Linear-4 [-1, 4] 20\n", + " Linear-5 [-1, 4] 20\n", + " 
+      "           Softmax-6                   [-1, 50]               0\n",
+      "            Linear-7                    [-1, 4]              20\n",
+      "            Linear-8                    [-1, 4]              20\n",
+      "            Linear-9                    [-1, 4]              20\n",
+      "          Softmax-10                   [-1, 50]               0\n",
+      "           Linear-11                    [-1, 4]              20\n",
+      "           Linear-12                    [-1, 4]              20\n",
+      "           Linear-13                    [-1, 4]              20\n",
+      "          Softmax-14                   [-1, 50]               0\n",
+      "           Linear-15                    [-1, 4]              20\n",
+      "           Linear-16                    [-1, 4]              20\n",
+      "           Linear-17                    [-1, 4]              20\n",
+      "          Softmax-18                   [-1, 50]               0\n",
+      "            MyMSA-19                [-1, 50, 8]               0\n",
+      "        LayerNorm-20                [-1, 50, 8]              16\n",
+      "           Linear-21               [-1, 50, 32]             288\n",
+      "             GELU-22               [-1, 50, 32]               0\n",
+      "           Linear-23                [-1, 50, 8]             264\n",
+      "       MyViTBlock-24                [-1, 50, 8]               0\n",
+      "        LayerNorm-25                [-1, 50, 8]              16\n",
+      "           Linear-26                    [-1, 4]              20\n",
+      "           Linear-27                    [-1, 4]              20\n",
+      "           Linear-28                    [-1, 4]              20\n",
+      "          Softmax-29                   [-1, 50]               0\n",
+      "           Linear-30                    [-1, 4]              20\n",
+      "           Linear-31                    [-1, 4]              20\n",
+      "           Linear-32                    [-1, 4]              20\n",
+      "          Softmax-33                   [-1, 50]               0\n",
+      "           Linear-34                    [-1, 4]              20\n",
+      "           Linear-35                    [-1, 4]              20\n",
+      "           Linear-36                    [-1, 4]              20\n",
+      "          Softmax-37                   [-1, 50]               0\n",
+      "           Linear-38                    [-1, 4]              20\n",
+      "           Linear-39                    [-1, 4]              20\n",
+      "           Linear-40                    [-1, 4]              20\n",
+      "          Softmax-41                   [-1, 50]               0\n",
+      "            MyMSA-42                [-1, 50, 8]               0\n",
+      "        LayerNorm-43                [-1, 50, 8]              16\n",
+      "           Linear-44               [-1, 50, 32]             288\n",
+      "             GELU-45               [-1, 50, 32]               0\n",
+      "           Linear-46                [-1, 50, 8]             264\n",
+      "       MyViTBlock-47                [-1, 50, 8]               0\n",
+      "           Linear-48                   [-1, 10]              90\n",
+      "          Softmax-49                   [-1, 10]               0\n",
+      "================================================================\n",
+      "Total params: 26,322\n",
+      "Trainable params: 26,322\n",
+      "Non-trainable params: 0\n",
+      "----------------------------------------------------------------\n",
+      "Input size (MB): 0.57\n",
+      "Forward/backward pass size (MB): 0.09\n",
+      "Params size (MB): 0.10\n",
+      "Estimated Total Size (MB): 0.76\n",
+      "----------------------------------------------------------------\n"
+     ]
+    }
+   ],
+   "source": [
+    "from torchsummary import summary\n",
+    "#print(model)\n",
+    "summary(model=MyViT((3, 224, 224), n_patches=7, n_blocks=2, hidden_d=8, n_heads=2, out_d=10), input_size=(3,224,224), device='cpu')"
+   ]
+  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
-    "architecture of transunet"
+    "architecture of transunet\n",
+    "\n",
+    "reference: https://github.com/mkara44/transunet_pytorch"
   ]
  },
  {
@@ -1116,6 +1354,24 @@
     "### Keras 3 examples"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "3.0.1\n"
+     ]
+    }
+   ],
+   "source": [
+    "import keras\n",
+    "print(keras.__version__)"
+   ]
+  },
   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
@@ -1140,7 +1396,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.11.6"
+   "version": "3.11.5"
   },
   "vscode": {
    "interpreter": {
diff --git a/code/figures/vit.png b/code/figures/vit.png
new file mode 100644
index 0000000..6f6c937
Binary files /dev/null and b/code/figures/vit.png differ
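
Reviewer note (not part of the diff): a minimal smoke test for the MyViT cell added above could look like the sketch below. It assumes that cell has already been executed so MyViT, patchify and get_positional_embeddings are in scope, and it reuses the same constructor arguments as the torchsummary call; the batch size of 4 is arbitrary.

# Smoke test sketch for the MyViT model defined in the new notebook cell.
import torch

model = MyViT((3, 224, 224), n_patches=7, n_blocks=2, hidden_d=8, n_heads=2, out_d=10)
x = torch.rand(4, 3, 224, 224)   # dummy batch of 4 RGB 224x224 images
with torch.no_grad():
    probs = model(x)             # patchify -> linear mapper -> CLS token -> blocks -> head
print(probs.shape)               # expected: torch.Size([4, 10])
print(probs.sum(dim=-1))         # ~1.0 per row, since the classification head ends in Softmax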