From 4a487501ad4aebec1deea8808d6309fa2a764376 Mon Sep 17 00:00:00 2001
From: Taleir of Deynai
Date: Sat, 12 Nov 2022 18:44:20 -0800
Subject: [PATCH 1/2] Adds embedding shuffling, with several modes.

---
 ldm/modules/embedding_manager.py  |   7 +-
 ldm/modules/embedding_shuffler.py | 129 ++++++++++++++++++++++++++++++
 2 files changed, 135 insertions(+), 1 deletion(-)
 create mode 100644 ldm/modules/embedding_shuffler.py

diff --git a/ldm/modules/embedding_manager.py b/ldm/modules/embedding_manager.py
index cbabc41..e1f0f87 100644
--- a/ldm/modules/embedding_manager.py
+++ b/ldm/modules/embedding_manager.py
@@ -1,6 +1,8 @@
 import torch
 from torch import nn
+from ldm.util import default
+from ldm.modules.embedding_shuffler import get_shuffler
 from ldm.data.personalized import per_img_token_list
 from transformers import CLIPTokenizer
 from functools import partial
@@ -35,6 +37,7 @@ def __init__(
             embedder,
             placeholder_strings=None,
             initializer_words=None,
+            shuffle_mode=None,
             per_image_tokens=False,
             num_vectors_per_token=1,
             progressive_words=False,
@@ -48,6 +51,7 @@ def __init__(
 
         self.initial_embeddings = nn.ParameterDict() # These should not be optimized
 
+        self.shuffle_embeddings = get_shuffler(default(shuffle_mode, "off"))
         self.progressive_words = progressive_words
         self.progressive_counter = 0
 
@@ -107,6 +111,7 @@ def forward(
                     max_step_tokens = self.max_vectors_per_token
 
                 num_vectors_for_token = min(placeholder_embedding.shape[0], max_step_tokens)
+                shuffle_view = self.shuffle_embeddings(placeholder_embedding, num_vectors_for_token)
 
                 placeholder_rows, placeholder_cols = torch.where(tokenized_text == placeholder_token.to(device))
 
@@ -121,7 +126,7 @@ def forward(
                     col = sorted_cols[idx]
 
                     new_token_row = torch.cat([tokenized_text[row][:col], placeholder_token.repeat(num_vectors_for_token).to(device), tokenized_text[row][col + 1:]], axis=0)[:n]
-                    new_embed_row = torch.cat([embedded_text[row][:col], placeholder_embedding[:num_vectors_for_token], embedded_text[row][col + 1:]], axis=0)[:n]
+                    new_embed_row = torch.cat([embedded_text[row][:col], shuffle_view, embedded_text[row][col + 1:]], axis=0)[:n]
 
                     embedded_text[row]  = new_embed_row
                     tokenized_text[row] = new_token_row
diff --git a/ldm/modules/embedding_shuffler.py b/ldm/modules/embedding_shuffler.py
new file mode 100644
index 0000000..2cf7f5c
--- /dev/null
+++ b/ldm/modules/embedding_shuffler.py
@@ -0,0 +1,129 @@
+from typing import Union, Callable, Optional, Literal
+
+import torch
+from torch import Tensor
+
+from ldm.util import default
+
+ShuffleMode = Union[
+    Literal["off"],
+    Literal["on", "all"],
+    Literal["trailing", "leading", "between"],
+    Literal["progressive", "dynamic"]
+]
+ShuffleFn = Union[
+    Callable[[Tensor], Tensor],
+    Callable[[Tensor, Optional[int]], Tensor]
+]
+
+def idx_of(value: int, device: torch.device):
+    """Helper that makes single-value tensors for some device."""
+    return torch.tensor([value], dtype=torch.int64, device=device)
+
+def shuffle_off(placeholder_embedding: Tensor, num_vectors: Optional[int]=None):
+    """Performs no shuffling, but will still trim to the number of vectors."""
+    num_vectors = default(num_vectors, placeholder_embedding.shape[0])
+    return placeholder_embedding[:num_vectors]
+
+def shuffle_all(placeholder_embedding: Tensor, num_vectors: Optional[int]=None):
+    """Shuffles all embeddings."""
+    num_vectors = default(num_vectors, placeholder_embedding.shape[0])
+    d = placeholder_embedding.device
+    if num_vectors >= 2:
+        trim_source = placeholder_embedding[:num_vectors]
+        shuffle_idx = torch.randperm(num_vectors, device=d)
+        return trim_source[shuffle_idx].view(trim_source.size())
+
+    # No effect with fewer than 2 vectors.
+    return shuffle_off(placeholder_embedding, num_vectors)
+
+def shuffle_trailing(placeholder_embedding: Tensor, num_vectors: Optional[int]=None):
+    """Shuffles everything after the first embedding."""
+    num_vectors = default(num_vectors, placeholder_embedding.shape[0])
+    d = placeholder_embedding.device
+    if num_vectors >= 3:
+        trim_source = placeholder_embedding[:num_vectors]
+        shuffle_idx = torch.randperm(num_vectors - 1, device=d) + 1
+        shuffle_idx = torch.cat([idx_of(0, d), shuffle_idx])
+        return trim_source[shuffle_idx].view(trim_source.size())
+
+    # No effect with fewer than 3 vectors.
+    return shuffle_off(placeholder_embedding, num_vectors)
+
+def shuffle_leading(placeholder_embedding: Tensor, num_vectors: Optional[int]=None):
+    """Shuffles everything before the last embedding."""
+    num_vectors = default(num_vectors, placeholder_embedding.shape[0])
+    d = placeholder_embedding.device
+    if num_vectors >= 3:
+        trim_source = placeholder_embedding[:num_vectors]
+        shuffle_idx = torch.randperm(num_vectors - 1, device=d)
+        shuffle_idx = torch.cat([shuffle_idx, idx_of(num_vectors - 1, d)])
+        return trim_source[shuffle_idx].view(trim_source.size())
+
+    # No effect with fewer than 3 vectors.
+    return shuffle_off(placeholder_embedding, num_vectors)
+
+def shuffle_between(placeholder_embedding: Tensor, num_vectors: Optional[int]=None):
+    """Shuffles between the first and last embeddings."""
+    num_vectors = default(num_vectors, placeholder_embedding.shape[0])
+    d = placeholder_embedding.device
+    if num_vectors >= 4:
+        trim_source = placeholder_embedding[:num_vectors]
+        shuffle_idx = torch.randperm(num_vectors - 2, device=d) + 1
+        shuffle_idx = torch.cat([idx_of(0, d), shuffle_idx, idx_of(num_vectors - 1, d)])
+        return trim_source[shuffle_idx].view(trim_source.size())
+
+    # No effect with fewer than 4 vectors.
+    return shuffle_off(placeholder_embedding, num_vectors)
+
+def shuffle_progressive(placeholder_embedding: Tensor, num_vectors: Optional[int]=None):
+    """
+    Always includes the first and last embeddings (if `num_vectors` is large enough)
+    while shuffling the embeddings in between. Unlike `shuffle_dynamic`, this
+    establishes stable intro and outro embeddings ASAP.
+
+    This was made as an option for progressive words mode.
+    """
+    num_vectors = default(num_vectors, placeholder_embedding.shape[0])
+    d = placeholder_embedding.device
+    if num_vectors == 2:
+        # Only `[first, last]`.
+        last_idx = placeholder_embedding.shape[0] - 1
+        shuffle_idx = torch.cat([idx_of(0, d), idx_of(last_idx, d)])
+        return placeholder_embedding[shuffle_idx].view(num_vectors, -1)
+    if num_vectors > 2:
+        # Now `[first, ..., last]`
+        last_idx = placeholder_embedding.shape[0] - 1
+        shuffle_idx = torch.randperm(num_vectors-2, device=d) + 1
+        shuffle_idx = torch.cat([idx_of(0, d), shuffle_idx, idx_of(last_idx, d)])
+        return placeholder_embedding[shuffle_idx].view(num_vectors, -1)
+
+    # No effect with fewer than 2 vectors.
+    return shuffle_off(placeholder_embedding, num_vectors)
+
+def shuffle_dynamic(placeholder_embedding: Tensor, num_vectors: Optional[int]=None):
+    """
+    Tries to always perform an embedding shuffle when possible.
+
+    The type of shuffle done depends on the number of vectors:
+    * 4 or more uses `between` shuffling.
+    * 3 uses `trailing` shuffling.
+    * 2 or less uses `all` shuffling.
+ """ + num_vectors = default(num_vectors, placeholder_embedding.shape[0]) + if num_vectors >= 4: return shuffle_between(placeholder_embedding, num_vectors) + if num_vectors == 3: return shuffle_trailing(placeholder_embedding, num_vectors) + return shuffle_all(placeholder_embedding, num_vectors) + +def get_shuffler(shuffle_mode: Union[bool, ShuffleMode]) -> ShuffleFn: + if shuffle_mode == True: shuffle_mode = "all" + elif shuffle_mode == "on": shuffle_mode = "all" + elif shuffle_mode == False: shuffle_mode = "off" + + if shuffle_mode == "all": return shuffle_all + if shuffle_mode == "dynamic": return shuffle_dynamic + if shuffle_mode == "progressive": return shuffle_progressive + if shuffle_mode == "between": return shuffle_between + if shuffle_mode == "trailing": return shuffle_trailing + if shuffle_mode == "leading": return shuffle_leading + return shuffle_off \ No newline at end of file From b1df9363247f373c5089f996a00b73c432959e13 Mon Sep 17 00:00:00 2001 From: Taleir of Deynai Date: Wed, 16 Nov 2022 16:30:07 -0800 Subject: [PATCH 2/2] A couple small tweeks. Shuffles the embedding for each prompt in the batch individually. The embedding does remain consistent within each prompt, in case there is a prompt like: "a painting of * smiling beside *" `shuffle_off` no longer performs indexing when `num_vectors` is the same as the max number of vectors. Probably doesn't matter, but it likely gets rid of some needless indirection. --- ldm/modules/embedding_manager.py | 2 +- ldm/modules/embedding_shuffler.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/ldm/modules/embedding_manager.py b/ldm/modules/embedding_manager.py index e1f0f87..a104f94 100644 --- a/ldm/modules/embedding_manager.py +++ b/ldm/modules/embedding_manager.py @@ -111,7 +111,6 @@ def forward( max_step_tokens = self.max_vectors_per_token num_vectors_for_token = min(placeholder_embedding.shape[0], max_step_tokens) - shuffle_view = self.shuffle_embeddings(placeholder_embedding, num_vectors_for_token) placeholder_rows, placeholder_cols = torch.where(tokenized_text == placeholder_token.to(device)) @@ -125,6 +124,7 @@ def forward( row = sorted_rows[idx] col = sorted_cols[idx] + shuffle_view = self.shuffle_embeddings(placeholder_embedding, num_vectors_for_token) new_token_row = torch.cat([tokenized_text[row][:col], placeholder_token.repeat(num_vectors_for_token).to(device), tokenized_text[row][col + 1:]], axis=0)[:n] new_embed_row = torch.cat([embedded_text[row][:col], shuffle_view, embedded_text[row][col + 1:]], axis=0)[:n] diff --git a/ldm/modules/embedding_shuffler.py b/ldm/modules/embedding_shuffler.py index 2cf7f5c..f38d38c 100644 --- a/ldm/modules/embedding_shuffler.py +++ b/ldm/modules/embedding_shuffler.py @@ -23,6 +23,8 @@ def idx_of(value: int, device: torch.device): def shuffle_off(placeholder_embedding: Tensor, num_vectors: Optional[int]=None): """Performs no shuffling, but will still trim to the number of vectors.""" num_vectors = default(num_vectors, placeholder_embedding.shape[0]) + if num_vectors == placeholder_embedding.shape[0]: + return placeholder_embedding return placeholder_embedding[:num_vectors] def shuffle_all(placeholder_embedding: Tensor, num_vectors: Optional[int]=None):