
Commit

fix: open_clip library changed default transformer behavior to `batch_first=True`
FlorianFuerrutter committed Jul 16, 2024
1 parent 3e47ff7 commit 4879009
Showing 10 changed files with 288 additions and 6,197 deletions.
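Background for the change: newer open_clip releases build the text `Transformer` with `batch_first=True` by default, so the embedder can no longer unconditionally permute between NLD and LND layouts; the diff below guards every permute on `model.transformer.batch_first`. A quick way to see which layout the installed open_clip uses (a minimal sketch; `ViT-B-32` is only a small example architecture, the commit itself targets `ViT-H-14`):

    import open_clip

    # Build with random weights so nothing is downloaded.
    model = open_clip.create_model("ViT-B-32", pretrained=None)
    print(open_clip.__version__)
    # Older releases may not expose the attribute at all, hence the getattr default.
    print(getattr(model.transformer, "batch_first", False))  # True on recent open_clip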
1 change: 1 addition & 0 deletions README.md
@@ -1,5 +1,6 @@
# genQC · Generative Quantum Circuits


<!-- WARNING: THIS FILE WAS AUTOGENERATED! DO NOT EDIT! -->
<p align="left">
<a><img src="https://badgen.net/badge/icon/awesome?icon=awesome&label" alt="awesome"></a>
38 changes: 23 additions & 15 deletions genQC/models/frozen_open_clip.py
@@ -38,7 +38,9 @@ def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cpu", m

self.model = model
self.to(device)


self.tokenizer = open_clip.get_tokenizer(arch)

assert max_length <= 77 # max set by the clip
self.max_length = max_length

@@ -68,7 +70,8 @@ def to(self, device):

@torch.no_grad()
def tokenize_and_push_to_device(self, text, to_device=True):
tokens = open_clip.tokenize(text)
# tokens = open_clip.tokenize(text)
tokens = self.tokenizer(text)
if to_device:
tokens = tokens.to(self.device)
return tokens
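The module-level `open_clip.tokenize` helper is replaced by the per-model tokenizer from `open_clip.get_tokenizer(arch)`, which `__init__` now stores as `self.tokenizer`. A minimal standalone usage sketch (assumes `open_clip_torch` is installed; the prompts mirror the notebook cells further down):

    import open_clip

    tokenizer = open_clip.get_tokenizer("ViT-H-14")
    tokens = tokenizer(["1,1,2", "2,2,2"])  # padded to CLIP's 77-token context
    print(tokens.shape)                     # torch.Size([2, 77])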
@@ -79,25 +82,30 @@ def forward(self, c, **kwargs):

@torch.no_grad()
def encode_with_transformer(self, text):
x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
x = x + self.model.positional_embedding[None, :x.shape[1]]
x = x.permute(1, 0, 2) # NLD -> LND
cast_dtype = self.model.transformer.get_cast_dtype()

x = self.model.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
x = x + self.model.positional_embedding[None, :x.shape[1]].to(cast_dtype)

if not self.model.transformer.batch_first:
x = x.permute(1, 0, 2) # NLD -> LND

x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
x = x.permute(1, 0, 2) # LND -> NLD
x = self.model.ln_final(x)

if not self.model.transformer.batch_first:
x = x.permute(1, 0, 2) # LND -> NLD

x = self.model.ln_final(x) # [batch_size, n_ctx, transformer.width]

return x

@torch.no_grad()
def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
for i, r in enumerate(self.model.transformer.resblocks):
if i == len(self.model.transformer.resblocks) - self.layer_idx:
break
#if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting():
#x = checkpoint(r, x, attn_mask)
#else:

x = r(x, attn_mask=attn_mask)


x = r(x, attn_mask=attn_mask)
return x

#--------------------------------------------------------------
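The reworked `encode_with_transformer` permutes only when the transformer was built with `batch_first=False`, and casts the token and positional embeddings to the transformer's cast dtype, matching how open_clip's own text encoder handles fp16 weights. `text_transformer_forward` still stops `layer_idx` residual blocks before the end to expose penultimate-layer features. A toy shape-flow sketch of the guarded permutes (stand-in modules with ViT-H-14 sizes; not the real open_clip model):

    import torch
    import torch.nn as nn

    token_embedding = nn.Embedding(49408, 1024)   # CLIP vocab size / ViT-H-14 width
    positional_embedding = torch.zeros(77, 1024)
    ln_final = nn.LayerNorm(1024)
    batch_first = True                            # what newer open_clip defaults to

    text = torch.randint(0, 49408, (2, 77))       # [batch_size, n_ctx]
    x = token_embedding(text)                     # [2, 77, 1024]
    x = x + positional_embedding[None, :x.shape[1]]
    if not batch_first:
        x = x.permute(1, 0, 2)                    # NLD -> LND, only for the old layout
    # ... the transformer resblocks (minus the last `layer_idx`) would run here ...
    if not batch_first:
        x = x.permute(1, 0, 2)                    # LND -> NLD
    x = ln_final(x)
    print(x.shape)                                # torch.Size([2, 77, 1024])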
Expand All @@ -113,7 +121,7 @@ def from_config(config, device: torch.device, save_path: str=None):
config["save_path"] = None
return Config_Model.from_config(config, device, save_path=None)

# %% ../../src/models/frozen_open_clip.ipynb 13
# %% ../../src/models/frozen_open_clip.ipynb 17
class CachedFrozenOpenCLIPEmbedder(FrozenOpenCLIPEmbedder):
"""Adds caching support to `FrozenOpenCLIPEmbedder`."""

@@ -141,7 +149,7 @@ def generate_cache(self, str_list: list=None, tokens=None, cached_empty_token_in

if i == 0:
mem = n * x.shape[1] * x.shape[2] * x.element_size() * 1e-9
print(f"[INFO]: caching trying to allocate memory {(n, x.shape[1], x.shape[2])} on {'cpu' if y_on_cpu else self.device} approx. {mem:.3f} GB")
print(f"[INFO]: caching trying to allocate memory {(n, x.shape[1], x.shape[2])} on {'cpu' if y_on_cpu else self.device}, approx. {mem:.3f} GB")
self.cached_embeddings = torch.zeros((n, x.shape[1], x.shape[2]), device="cpu" if y_on_cpu else self.device, dtype=x.dtype) # alloc huge memory !!

self.cached_embeddings[last_ind:last_ind+x.shape[0]] = x.to(self.cached_embeddings.device)
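The `[INFO]` print in `generate_cache` estimates the cache size as elements times bytes per element; for the (2, 77, 1024) float32 cache allocated in the notebook below that is about 0.0006 GB, which the `{mem:.3f}` format rounds to the 0.001 GB shown in the log. A quick check of the arithmetic:

    import torch

    x = torch.zeros((2, 77, 1024), dtype=torch.float32)
    mem = x.numel() * x.element_size() * 1e-9   # 157696 elements * 4 bytes ~= 0.00063 GB
    print(f"approx. {mem:.3f} GB")              # approx. 0.001 GB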
Binary file modified index_files/figure-commonmark/cell-3-output-2.png
19 changes: 10 additions & 9 deletions settings.ini
@@ -28,12 +28,13 @@ custom_quarto_yml = True
custom_sidebar = True

### PyPI ###
author = Florian Fuerrutter
author_email = [email protected]
copyright = 2023 onwards, %(author)s
audience = Developers
description = Generating quantum circuits with diffusion models
keywords = quantum-information diffusion-model generative-model
language = English
status = 3
requirements = torch numpy matplotlib scipy pandas omegaconf qiskit tqdm joblib open_clip_torch ipywidgets pylatexenc
author = Florian Fuerrutter
author_email = [email protected]
copyright = 2023 onwards, %(author)s
audience = Developers
description = Generating quantum circuits with diffusion models
keywords = quantum-information diffusion-model generative-model
language = English
status = 3
requirements = torch numpy matplotlib scipy pandas omegaconf qiskit tqdm joblib open_clip_torch ipywidgets pylatexenc
dev_requirements = jupyterlab nbdev
225 changes: 15 additions & 210 deletions src/examples/0_hello_circuit.ipynb

Large diffs are not rendered by default.

235 changes: 16 additions & 219 deletions src/examples/1_editing_and_masking.ipynb

Large diffs are not rendered by default.

28 changes: 18 additions & 10 deletions src/examples/2_unitary_compilation.ipynb

Large diffs are not rendered by default.

5,784 changes: 79 additions & 5,705 deletions src/examples/3_dataset_and_fineTune.ipynb

Large diffs are not rendered by default.

24 changes: 12 additions & 12 deletions src/index.ipynb

Large diffs are not rendered by default.

131 changes: 114 additions & 17 deletions src/models/frozen_open_clip.ipynb
@@ -84,7 +84,9 @@
" \n",
" self.model = model\n",
" self.to(device)\n",
" \n",
"\n",
" self.tokenizer = open_clip.get_tokenizer(arch)\n",
" \n",
" assert max_length <= 77 # max set by the clip \n",
" self.max_length = max_length\n",
" \n",
@@ -114,7 +116,8 @@
" \n",
" @torch.no_grad()\n",
" def tokenize_and_push_to_device(self, text, to_device=True):\n",
" tokens = open_clip.tokenize(text)\n",
" # tokens = open_clip.tokenize(text)\n",
" tokens = self.tokenizer(text)\n",
" if to_device:\n",
" tokens = tokens.to(self.device)\n",
" return tokens\n",
@@ -125,25 +128,30 @@
"\n",
" @torch.no_grad()\n",
" def encode_with_transformer(self, text):\n",
" x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]\n",
" x = x + self.model.positional_embedding[None, :x.shape[1]]\n",
" x = x.permute(1, 0, 2) # NLD -> LND\n",
" cast_dtype = self.model.transformer.get_cast_dtype()\n",
" \n",
" x = self.model.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] \n",
" x = x + self.model.positional_embedding[None, :x.shape[1]].to(cast_dtype)\n",
"\n",
" if not self.model.transformer.batch_first:\n",
" x = x.permute(1, 0, 2) # NLD -> LND\n",
" \n",
" x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)\n",
" x = x.permute(1, 0, 2) # LND -> NLD\n",
" x = self.model.ln_final(x)\n",
"\n",
" if not self.model.transformer.batch_first:\n",
" x = x.permute(1, 0, 2) # LND -> NLD\n",
" \n",
" x = self.model.ln_final(x) # [batch_size, n_ctx, transformer.width]\n",
" \n",
" return x\n",
"\n",
" @torch.no_grad()\n",
" def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):\n",
" for i, r in enumerate(self.model.transformer.resblocks):\n",
" if i == len(self.model.transformer.resblocks) - self.layer_idx:\n",
" break\n",
" #if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting():\n",
" #x = checkpoint(r, x, attn_mask)\n",
" #else:\n",
" \n",
" x = r(x, attn_mask=attn_mask)\n",
" \n",
"\n",
" x = r(x, attn_mask=attn_mask) \n",
" return x\n",
"\n",
" #--------------------------------------------------------------\n",
@@ -215,6 +223,48 @@
"a.tokenize_and_push_to_device(\"\").shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "15e7dcf8-5836-48a7-8b21-d118f8f11996",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([2, 77])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"a.tokenize_and_push_to_device([\"1,1,2\", \"2,2,2\"]).shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4902b7cf-1eed-4b82-b4e5-c1ece8bfd416",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([77, 77])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"a.model.attn_mask.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -231,7 +281,7 @@
" ...,\n",
" [ 0.4703, -1.4072, -0.4847, ..., -0.1257, -0.1650, 0.1206],\n",
" [ 0.5117, -1.3949, -0.4672, ..., -0.4288, -0.2166, 0.2904],\n",
" [ 0.1480, -2.1998, -1.1187, ..., 0.0823, -0.4157, 0.6236]],\n",
" [ 0.1480, -2.1998, -1.1187, ..., 0.0823, -0.4157, 0.6237]],\n",
" \n",
" [[-0.3134, -0.4476, -0.0082, ..., 0.2542, -0.0324, -0.2960],\n",
" [-0.1180, -1.6322, 1.2987, ..., -0.1378, -0.1529, -0.3377],\n",
Expand All @@ -253,6 +303,53 @@
"enc.shape, enc"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8efcdfb1-b8c4-44c4-b15a-2900df2d3cb6",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[49406, 272, 267, 272, 267, 273, 49407, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [49406, 273, 267, 273, 267, 273, 49407, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"c"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "952cf514-2838-4516-8313-51b4507c8cfd",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'<start_of_text>2 , 2 , 2 <end_of_text>!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!'"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"a.tokenizer.decode(c[1].tolist())"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -325,7 +422,7 @@
" \n",
" if i == 0:\n",
" mem = n * x.shape[1] * x.shape[2] * x.element_size() * 1e-9\n",
" print(f\"[INFO]: caching trying to allocate memory {(n, x.shape[1], x.shape[2])} on {'cpu' if y_on_cpu else self.device} approx. {mem:.3f} GB\")\n",
" print(f\"[INFO]: caching trying to allocate memory {(n, x.shape[1], x.shape[2])} on {'cpu' if y_on_cpu else self.device}, approx. {mem:.3f} GB\")\n",
" self.cached_embeddings = torch.zeros((n, x.shape[1], x.shape[2]), device=\"cpu\" if y_on_cpu else self.device, dtype=x.dtype) # alloc huge memory !!\n",
" \n",
" self.cached_embeddings[last_ind:last_ind+x.shape[0]] = x.to(self.cached_embeddings.device)\n",
@@ -374,7 +471,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "6db77b94925e45d4879fa75bc884dd38",
"model_id": "5a795ff325c540ccbb1407ab880c633f",
"version_major": 2,
"version_minor": 0
},
@@ -389,7 +486,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"[INFO]: caching trying to allocate memory (2, 77, 1024) on cpu approx. 0.001 GB\n"
"[INFO]: caching trying to allocate memory (2, 77, 1024) on cpu, approx. 0.001 GB\n"
]
}
],
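The decode cell in the notebook above makes the tokenizer's framing visible: 49406 and 49407 are the start/end-of-text ids, the remainder of the 77-token context is zero-padded, and id 0 decodes to '!' in CLIP's BPE vocabulary, hence the trailing run of exclamation marks. A small sketch reproducing it (assumes `open_clip_torch`; the prompt matches the cell above):

    import open_clip

    tokenizer = open_clip.get_tokenizer("ViT-H-14")
    tokens = tokenizer(["2,2,2"])[0]
    print(tokens[:8].tolist())                # [49406, 273, 267, 273, 267, 273, 49407, 0]
    print(tokenizer.decode(tokens.tolist()))  # '<start_of_text>2 , 2 , 2 <end_of_text>!!!...'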
