Commit 6faba65: update

jizhang02 committed Dec 10, 2023
1 parent cb4aa0e commit 6faba65

Showing 2 changed files with 261 additions and 5 deletions.

266 changes: 261 additions & 5 deletions code/dnn_marathon.ipynb

@@ -921,21 +921,259 @@
"source": [
"architecture of transformer\n",
"\n",
"reference: https://medium.com/mlearning-ai/vision-transformers-from-scratch-pytorch-a-step-by-step-guide-96c3313c2e0c"
"reference: https://medium.com/mlearning-ai/vision-transformers-from-scratch-pytorch-a-step-by-step-guide-96c3313c2e0c\n",
"\n",
"![image.png](figures/vit.png)"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": []
"source": [
"import numpy as np\n",
"\n",
"from tqdm import tqdm, trange\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"from torch.optim import Adam\n",
"from torch.nn import CrossEntropyLoss\n",
"from torch.utils.data import DataLoader\n",
"\n",
"from torchvision.transforms import ToTensor\n",
"from torchvision.datasets.mnist import MNIST\n",
"\n",
"np.random.seed(0)\n",
"torch.manual_seed(0)\n",
"\n",
"def patchify(images, n_patches):\n",
" n, c, h, w = images.shape\n",
"\n",
" assert h == w, \"Patchify method is implemented for square images only\"\n",
"\n",
" patches = torch.zeros(n, n_patches ** 2, h * w * c // n_patches ** 2)\n",
" patch_size = h // n_patches\n",
"\n",
" for idx, image in enumerate(images):\n",
" for i in range(n_patches):\n",
" for j in range(n_patches):\n",
" patch = image[:, i * patch_size: (i + 1) * patch_size, j * patch_size: (j + 1) * patch_size]\n",
" patches[idx, i * n_patches + j] = patch.flatten()\n",
" return patches\n",
"\n",
"def get_positional_embeddings(sequence_length, d):\n",
" result = torch.ones(sequence_length, d)\n",
" for i in range(sequence_length):\n",
" for j in range(d):\n",
" result[i][j] = np.sin(i / (10000 ** (j / d))) if j % 2 == 0 else np.cos(i / (10000 ** ((j - 1) / d)))\n",
" return result\n",
"\n",
"class MyMSA(nn.Module):\n",
" def __init__(self, d, n_heads=2):\n",
" super(MyMSA, self).__init__()\n",
" self.d = d\n",
" self.n_heads = n_heads\n",
"\n",
" assert d % n_heads == 0, f\"Can't divide dimension {d} into {n_heads} heads\"\n",
"\n",
" d_head = int(d / n_heads)\n",
" self.q_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for _ in range(self.n_heads)])\n",
" self.k_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for _ in range(self.n_heads)])\n",
" self.v_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for _ in range(self.n_heads)])\n",
" self.d_head = d_head\n",
" self.softmax = nn.Softmax(dim=-1)\n",
"\n",
" def forward(self, sequences):\n",
" # Sequences has shape (N, seq_length, token_dim)\n",
" # We go into shape (N, seq_length, n_heads, token_dim / n_heads)\n",
" # And come back to (N, seq_length, item_dim) (through concatenation)\n",
" result = []\n",
" for sequence in sequences:\n",
" seq_result = []\n",
" for head in range(self.n_heads):\n",
" q_mapping = self.q_mappings[head]\n",
" k_mapping = self.k_mappings[head]\n",
" v_mapping = self.v_mappings[head]\n",
"\n",
" seq = sequence[:, head * self.d_head: (head + 1) * self.d_head]\n",
" q, k, v = q_mapping(seq), k_mapping(seq), v_mapping(seq)\n",
"\n",
" attention = self.softmax(q @ k.T / (self.d_head ** 0.5))\n",
" seq_result.append(attention @ v)\n",
" result.append(torch.hstack(seq_result))\n",
" return torch.cat([torch.unsqueeze(r, dim=0) for r in result])\n",
"\n",
"\n",
"class MyViTBlock(nn.Module):\n",
" def __init__(self, hidden_d, n_heads, mlp_ratio=4):\n",
" super(MyViTBlock, self).__init__()\n",
" self.hidden_d = hidden_d\n",
" self.n_heads = n_heads\n",
"\n",
" self.norm1 = nn.LayerNorm(hidden_d)\n",
" self.mhsa = MyMSA(hidden_d, n_heads)\n",
" self.norm2 = nn.LayerNorm(hidden_d)\n",
" self.mlp = nn.Sequential(\n",
" nn.Linear(hidden_d, mlp_ratio * hidden_d),\n",
" nn.GELU(),\n",
" nn.Linear(mlp_ratio * hidden_d, hidden_d)\n",
" )\n",
"\n",
" def forward(self, x):\n",
" out = x + self.mhsa(self.norm1(x))\n",
" out = out + self.mlp(self.norm2(out))\n",
" return out\n",
"\n",
"class MyViT(nn.Module):\n",
" def __init__(self, chw, n_patches=7, n_blocks=2, hidden_d=8, n_heads=2, out_d=10):\n",
" # Super constructor\n",
" super(MyViT, self).__init__()\n",
" \n",
" # Attributes\n",
" self.chw = chw # ( C , H , W )\n",
" self.n_patches = n_patches\n",
" self.n_blocks = n_blocks\n",
" self.n_heads = n_heads\n",
" self.hidden_d = hidden_d\n",
" \n",
" # Input and patches sizes\n",
" assert chw[1] % n_patches == 0, \"Input shape not entirely divisible by number of patches\"\n",
" assert chw[2] % n_patches == 0, \"Input shape not entirely divisible by number of patches\"\n",
" self.patch_size = (chw[1] / n_patches, chw[2] / n_patches)\n",
"\n",
" # 1) Linear mapper\n",
" self.input_d = int(chw[0] * self.patch_size[0] * self.patch_size[1])\n",
" self.linear_mapper = nn.Linear(self.input_d, self.hidden_d)\n",
" \n",
" # 2) Learnable classification token\n",
" self.class_token = nn.Parameter(torch.rand(1, self.hidden_d))\n",
" \n",
" # 3) Positional embedding\n",
" self.register_buffer('positional_embeddings', get_positional_embeddings(n_patches ** 2 + 1, hidden_d), persistent=False)\n",
" \n",
" # 4) Transformer encoder blocks\n",
" self.blocks = nn.ModuleList([MyViTBlock(hidden_d, n_heads) for _ in range(n_blocks)])\n",
" \n",
" # 5) Classification MLPk\n",
" self.mlp = nn.Sequential(\n",
" nn.Linear(self.hidden_d, out_d),\n",
" nn.Softmax(dim=-1)\n",
" )\n",
"\n",
" def forward(self, images):\n",
" # Dividing images into patches\n",
" n, c, h, w = images.shape\n",
" patches = patchify(images, self.n_patches).to(self.positional_embeddings.device)\n",
" \n",
" # Running linear layer tokenization\n",
" # Map the vector corresponding to each patch to the hidden size dimension\n",
" tokens = self.linear_mapper(patches)\n",
" \n",
" # Adding classification token to the tokens\n",
" tokens = torch.cat((self.class_token.expand(n, 1, -1), tokens), dim=1)\n",
" \n",
" # Adding positional embedding\n",
" out = tokens + self.positional_embeddings.repeat(n, 1, 1)\n",
" \n",
" # Transformer Blocks\n",
" for block in self.blocks:\n",
" out = block(out)\n",
" \n",
" # Getting the classification token only\n",
" out = out[:, 0]\n",
" \n",
" return self.mlp(out) # Map to output dimension, output category distribution"
]
},
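A quick sanity check of the cell above (a minimal sketch, not part of this commit; shapes follow from the MNIST-sized defaults): push a random batch through MyViT and confirm the output is one category distribution per image.

model = MyViT((1, 28, 28), n_patches=7, n_blocks=2, hidden_d=8, n_heads=2, out_d=10)
x = torch.randn(4, 1, 28, 28)   # fake batch of 4 MNIST-sized images
out = model(x)                  # patchify -> tokenize -> encoder blocks -> class head
print(out.shape)                # torch.Size([4, 10])
print(out.sum(dim=-1))          # each row sums to ~1 because of the final Softmax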
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"----------------------------------------------------------------\n",
" Layer (type) Output Shape Param #\n",
"================================================================\n",
" Linear-1 [-1, 49, 8] 24,584\n",
" LayerNorm-2 [-1, 50, 8] 16\n",
" Linear-3 [-1, 4] 20\n",
" Linear-4 [-1, 4] 20\n",
" Linear-5 [-1, 4] 20\n",
" Softmax-6 [-1, 50] 0\n",
" Linear-7 [-1, 4] 20\n",
" Linear-8 [-1, 4] 20\n",
" Linear-9 [-1, 4] 20\n",
" Softmax-10 [-1, 50] 0\n",
" Linear-11 [-1, 4] 20\n",
" Linear-12 [-1, 4] 20\n",
" Linear-13 [-1, 4] 20\n",
" Softmax-14 [-1, 50] 0\n",
" Linear-15 [-1, 4] 20\n",
" Linear-16 [-1, 4] 20\n",
" Linear-17 [-1, 4] 20\n",
" Softmax-18 [-1, 50] 0\n",
" MyMSA-19 [-1, 50, 8] 0\n",
" LayerNorm-20 [-1, 50, 8] 16\n",
" Linear-21 [-1, 50, 32] 288\n",
" GELU-22 [-1, 50, 32] 0\n",
" Linear-23 [-1, 50, 8] 264\n",
" MyViTBlock-24 [-1, 50, 8] 0\n",
" LayerNorm-25 [-1, 50, 8] 16\n",
" Linear-26 [-1, 4] 20\n",
" Linear-27 [-1, 4] 20\n",
" Linear-28 [-1, 4] 20\n",
" Softmax-29 [-1, 50] 0\n",
" Linear-30 [-1, 4] 20\n",
" Linear-31 [-1, 4] 20\n",
" Linear-32 [-1, 4] 20\n",
" Softmax-33 [-1, 50] 0\n",
" Linear-34 [-1, 4] 20\n",
" Linear-35 [-1, 4] 20\n",
" Linear-36 [-1, 4] 20\n",
" Softmax-37 [-1, 50] 0\n",
" Linear-38 [-1, 4] 20\n",
" Linear-39 [-1, 4] 20\n",
" Linear-40 [-1, 4] 20\n",
" Softmax-41 [-1, 50] 0\n",
" MyMSA-42 [-1, 50, 8] 0\n",
" LayerNorm-43 [-1, 50, 8] 16\n",
" Linear-44 [-1, 50, 32] 288\n",
" GELU-45 [-1, 50, 32] 0\n",
" Linear-46 [-1, 50, 8] 264\n",
" MyViTBlock-47 [-1, 50, 8] 0\n",
" Linear-48 [-1, 10] 90\n",
" Softmax-49 [-1, 10] 0\n",
"================================================================\n",
"Total params: 26,322\n",
"Trainable params: 26,322\n",
"Non-trainable params: 0\n",
"----------------------------------------------------------------\n",
"Input size (MB): 0.57\n",
"Forward/backward pass size (MB): 0.09\n",
"Params size (MB): 0.10\n",
"Estimated Total Size (MB): 0.76\n",
"----------------------------------------------------------------\n"
]
}
],
"source": [
"from torchsummary import summary\n",
"#print(model)\n",
"summary(model=MyViT((3, 224, 224), n_patches=7, n_blocks=2, hidden_d=8, n_heads=2, out_d=10), input_size=(3,224,224), device='cpu')"
]
},
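For completeness, a minimal training sketch wiring up the MNIST, DataLoader, Adam, and CrossEntropyLoss imports from the model cell; it is not part of this commit, and the hyperparameters (batch size 128, lr 0.005, 5 epochs) are illustrative. One caveat: CrossEntropyLoss expects raw logits, but MyViT already ends in a Softmax, so the loss below is applied to probabilities; this still trains, but dropping the final Softmax and letting the loss handle normalization is the numerically safer choice.

train_set = MNIST(root='./data', train=True, download=True, transform=ToTensor())
train_loader = DataLoader(train_set, batch_size=128, shuffle=True)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = MyViT((1, 28, 28), n_patches=7, n_blocks=2, hidden_d=8, n_heads=2, out_d=10).to(device)
optimizer = Adam(model.parameters(), lr=0.005)
criterion = CrossEntropyLoss()

for epoch in trange(5, desc='Training'):
    epoch_loss = 0.0
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        loss = criterion(model(x), y)  # model output: (N, 10) category distribution
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item() / len(train_loader)
    print(f'Epoch {epoch + 1}/5: loss {epoch_loss:.3f}')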
{
"cell_type": "markdown",
"metadata": {},
"source": [
"architecture of transunet"
"architecture of transunet\n",
"\n",
"reference: https://github.com/mkara44/transunet_pytorch"
]
},
{
@@ -1116,6 +1354,24 @@
"### Keras 3 examples"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"3.0.1\n"
]
}
],
"source": [
"import keras\n",
"print(keras.__version__)"
]
},
{
"cell_type": "markdown",
"metadata": {},
@@ -1140,7 +1396,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.6"
"version": "3.11.5"
},
"vscode": {
"interpreter": {
Binary file added code/figures/vit.png
