Refactor layers for CLIP text encoder of SD model (#30)
* Refactor layers for CLIP text encoder of SD model

* Update comments for the return values of the model loader.

* Remove the shared-gate feed-forward, which stemmed from an incorrect implementation of QuickGELU (see the sketch below).

* Remove SharedGatedFeedForward.

* Reformat loader.py
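For context, a minimal sketch of the change behind the shared-gate removal. This is illustrative only: the class and parameter names below are not the library's API; only the QuickGELU formula and the checkpoint names (layers.{}.linear_1 / linear_2) come from the diff. QuickGELU is the elementwise activation x * sigmoid(1.702 * x), so CLIP's feed-forward needs just two sequential linear layers, not a separate gate projection.

import torch
from torch import nn
import torch.nn.functional as F


class QuickGELUFeedForward(nn.Module):
  """Sequential feed-forward with QuickGELU, which is all CLIP needs."""

  def __init__(self, dim: int, hidden_dim: int):
    super().__init__()
    self.up = nn.Linear(dim, hidden_dim)    # corresponds to the checkpoint's layers.{}.linear_1
    self.down = nn.Linear(hidden_dim, dim)  # corresponds to the checkpoint's layers.{}.linear_2

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    h = self.up(x)
    h = h * torch.sigmoid(1.702 * h)  # QuickGELU: plain elementwise activation, no gate branch
    return self.down(h)


class GatedFeedForwardSketch(nn.Module):
  """GEGLU-style gated block, shown only for contrast; CLIP does not need this."""

  def __init__(self, dim: int, hidden_dim: int):
    super().__init__()
    self.gate = nn.Linear(dim, hidden_dim)
    self.up = nn.Linear(dim, hidden_dim)
    self.down = nn.Linear(hidden_dim, dim)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    return self.down(F.gelu(self.gate(x)) * self.up(x))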
yichunk authored Jun 5, 2024
1 parent 475607a commit ffc6b9c
Showing 7 changed files with 178 additions and 358 deletions.
132 changes: 83 additions & 49 deletions ai_edge_torch/generative/examples/stable_diffusion/clip.py
@@ -15,65 +15,99 @@

 import torch
 from torch import nn
-from torch._prims_common import mask_tensor
-from torch._prims_common.wrappers import out_wrapper

-from ai_edge_torch.generative.examples.stable_diffusion.attention import SelfAttention  # NOQA
+from ai_edge_torch.generative.layers.attention import TransformerBlock
+import ai_edge_torch.generative.layers.attention_utils as attention_utils
+import ai_edge_torch.generative.layers.builder as builder
+import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.utilities.loader as loading_utils

+TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
+    ff_up_proj="layers.{}.linear_1",
+    ff_down_proj="layers.{}.linear_2",
+    ff_gate_proj="layers.{}.linear_1",
+    attn_fused_qkv_proj="layers.{}.attention.in_proj",
+    attn_output_proj="layers.{}.attention.out_proj",
+    pre_attn_norm="layers.{}.layernorm_1",
+    pre_ff_norm="layers.{}.layernorm_2",
+    embedding="embedding.token_embedding",
+    embedding_position="embedding.position_value",
+    final_norm="layernorm",
+    lm_head=None,
+)


-class CLIPEmbedding(nn.Module):
-
-  def __init__(self, n_vocab: int, n_embd: int, n_token: int):
-    super().__init__()
-    self.token_embedding = nn.Embedding(n_vocab, n_embd)
-    self.position_value = nn.Parameter(torch.zeros((n_token, n_embd)))
-
-  def forward(self, tokens):
-    x = self.token_embedding(tokens)
-    x += self.position_value
-    return x
-
-
-class CLIPLayer(nn.Module):
+class CLIP(nn.Module):
+  """CLIP text encoder
+  For details, see https://arxiv.org/abs/2103.00020
+  """

-  def __init__(self, n_head: int, n_embd: int):
+  def __init__(self, config: cfg.ModelConfig):
     super().__init__()
-    self.layernorm_1 = nn.LayerNorm(n_embd)
-    self.attention = SelfAttention(n_head, n_embd)
-    self.layernorm_2 = nn.LayerNorm(n_embd)
-    self.linear_1 = nn.Linear(n_embd, 4 * n_embd)
-    self.linear_2 = nn.Linear(4 * n_embd, n_embd)
-
-  def forward(self, x):
-    residue = x
-    x = self.layernorm_1(x)
-    x = self.attention(x, causal_mask=True)
-    x += residue
+    self.tok_embedding = nn.Embedding(config.vocab_size, config.embedding_dim)
+    self.tok_embedding_position = nn.Parameter(
+        torch.zeros((config.max_seq_len, config.embedding_dim))
+    )

-    residue = x
-    x = self.layernorm_2(x)
-    x = self.linear_1(x)
-    x = x * torch.sigmoid(1.702 * x)  # QuickGELU activation function
-    x = self.linear_2(x)
-    x += residue
+    self.config = config
+    self.transformer_blocks = nn.ModuleList(
+        TransformerBlock(config) for _ in range(config.num_layers)
+    )
+    self.final_norm = builder.build_norm(config.embedding_dim, config.final_norm_config)

-    return x
-
-
-class CLIP(nn.Module):
-
-  def __init__(self):
-    super().__init__()
-    self.embedding = CLIPEmbedding(49408, 768, 77)
-    self.layers = nn.ModuleList([CLIPLayer(12, 768) for i in range(12)])
-    self.layernorm = nn.LayerNorm(768)
+    self.mask_cache = attention_utils.build_causal_mask_cache(
+        size=config.max_seq_len, dtype=torch.float32
+    )

   @torch.inference_mode
   def forward(self, tokens: torch.LongTensor) -> torch.FloatTensor:
     tokens = tokens.type(torch.long)

-    state = self.embedding(tokens)
-    for layer in self.layers:
-      state = layer(state)
-    output = self.layernorm(state)
+    state = self.tok_embedding(tokens) + self.tok_embedding_position
+    for layer in self.transformer_blocks:
+      state = layer(state, mask=self.mask_cache)
+    output = self.final_norm(state)
     return output


+def get_model_config() -> cfg.ModelConfig:
+  max_seq_len = 77
+  vocab_size = 49408
+  num_layers = 12
+  num_heads = 12
+  num_query_groups = 12
+  embedding_dim = 768
+
+  attn_config = cfg.AttentionConfig(
+      num_heads=num_heads,
+      num_query_groups=num_query_groups,
+      rotary_percentage=0.0,
+      qkv_use_bias=True,
+      qkv_transpose_before_split=True,
+      output_proj_use_bias=True,
+      enable_kv_cache=False,
+  )
+
+  ff_config = cfg.FeedForwardConfig(
+      type=cfg.FeedForwardType.SEQUENTIAL,
+      activation=cfg.ActivationType.GELU_QUICK,
+      intermediate_size=embedding_dim * 4,
+      use_bias=True,
+  )
+
+  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.LAYER_NORM)
+
+  config = cfg.ModelConfig(
+      vocab_size=vocab_size,
+      num_layers=num_layers,
+      max_seq_len=max_seq_len,
+      embedding_dim=embedding_dim,
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      pre_ff_norm_config=norm_config,
+      final_norm_config=norm_config,
+      enable_hlfb=True,
+  )
+
+  return config
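As a usage note, a minimal sketch of driving the refactored encoder end to end. The checkpoint path and dummy token tensor are placeholders; the loader and constructor calls mirror the conversion script in the next diff.

import torch

import ai_edge_torch.generative.examples.stable_diffusion.clip as clip
import ai_edge_torch.generative.utilities.loader as loading_utils

# Placeholder path; point this at the real Stable Diffusion CLIP checkpoint.
clip_ckpt_path = '/tmp/stable_diffusion/clip.pt'

clip_model = clip.CLIP(clip.get_model_config())
loader = loading_utils.ModelLoader(clip_ckpt_path, clip.TENSOR_NAMES)
loader.load(clip_model, strict=False)

# Dummy prompt tokens: batch of 1, padded to max_seq_len=77.
prompt_tokens = torch.zeros((1, 77), dtype=torch.long)
context = clip_model(prompt_tokens)  # text embeddings of shape (1, 77, 768)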
ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py

@@ -19,11 +19,12 @@

 import torch

 import ai_edge_torch
-from ai_edge_torch.generative.examples.stable_diffusion.clip import CLIP
+import ai_edge_torch.generative.examples.stable_diffusion.clip as clip
 from ai_edge_torch.generative.examples.stable_diffusion.decoder import Decoder
 from ai_edge_torch.generative.examples.stable_diffusion.diffusion import Diffusion  # NOQA
 from ai_edge_torch.generative.examples.stable_diffusion.encoder import Encoder
 import ai_edge_torch.generative.examples.stable_diffusion.util as util
+import ai_edge_torch.generative.utilities.loader as loading_utils


 @torch.inference_mode
@@ -36,8 +37,9 @@ def convert_stable_diffusion_to_tflite(
     image_width: int = 512,
 ):

-  clip = CLIP()
-  clip.load_state_dict(torch.load(clip_ckpt_path))
+  clip_model = clip.CLIP(clip.get_model_config())
+  loader = loading_utils.ModelLoader(clip_ckpt_path, clip.TENSOR_NAMES)
+  loader.load(clip_model, strict=False)

   encoder = Encoder()
   encoder.load_state_dict(torch.load(encoder_ckpt_path))
@@ -59,13 +61,13 @@
   )

   input_latents = encoder(input_image, noise)
-  context_cond = clip(prompt_tokens)
+  context_cond = clip_model(prompt_tokens)
   context_uncond = torch.zeros_like(context_cond)
   context = torch.cat([context_cond, context_uncond], axis=0)
   time_embedding = util.get_time_embedding(timestamp)

   # CLIP text encoder
-  ai_edge_torch.signature('encode', clip, (prompt_tokens,)).convert().export(
+  ai_edge_torch.signature('encode', clip_model, (prompt_tokens,)).convert().export(
       '/tmp/stable_diffusion/clip.tflite'
   )
