Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate stable diffusion example to ai-torch-edge #14

Merged
merged 2 commits into from
May 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions ai_edge_torch/generative/examples/stable_diffusion/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Copyright 2024 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
106 changes: 106 additions & 0 deletions ai_edge_torch/generative/examples/stable_diffusion/attention.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# Copyright 2024 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import math

import torch
from torch import _decomp
from torch import nn
from torch._prims_common import mask_tensor
from torch._prims_common.wrappers import out_wrapper
from torch.nn import functional as F


def triu(a):
h, w = a.shape[-2:]
mask = (
torch.arange(w, device=a.device).unsqueeze(-2)
- torch.arange(h, device=a.device).unsqueeze(-1)
) >= 1
mask = torch.broadcast_to(mask, a.shape)
return torch.ops.aten.logical_and(a, mask).contiguous()


# _decomp.decomposition_table[torch.ops.aten.triu.default] = triu


class SelfAttention(nn.Module):

def __init__(self, n_heads, d_embed, in_proj_bias=True, out_proj_bias=True):
super().__init__()
self.in_proj = nn.Linear(d_embed, 3 * d_embed, bias=in_proj_bias)
self.out_proj = nn.Linear(d_embed, d_embed, bias=out_proj_bias)
self.n_heads = n_heads
self.d_head = d_embed // n_heads

def forward(self, x, causal_mask=False):
input_shape = x.shape
batch_size, sequence_length, d_embed = input_shape
interim_shape = (batch_size, sequence_length, self.n_heads, self.d_head)

q, k, v = self.in_proj(x).chunk(3, dim=-1)

q = q.view(interim_shape).transpose(1, 2)
k = k.view(interim_shape).transpose(1, 2)
v = v.view(interim_shape).transpose(1, 2)

weight = q @ k.transpose(-1, -2)
if causal_mask:
# mask = torch.ones_like(weight, dtype=torch.bool).triu(1)
mask = triu(torch.ones_like(weight, dtype=torch.bool))
weight.masked_fill_(mask, -torch.inf)
weight /= math.sqrt(self.d_head)
weight = F.softmax(weight, dim=-1)

output = weight @ v
output = output.transpose(1, 2)
output = output.reshape(input_shape)
output = self.out_proj(output)
return output


class CrossAttention(nn.Module):

def __init__(self, n_heads, d_embed, d_cross, in_proj_bias=True, out_proj_bias=True):
super().__init__()
self.q_proj = nn.Linear(d_embed, d_embed, bias=in_proj_bias)
self.k_proj = nn.Linear(d_cross, d_embed, bias=in_proj_bias)
self.v_proj = nn.Linear(d_cross, d_embed, bias=in_proj_bias)
self.out_proj = nn.Linear(d_embed, d_embed, bias=out_proj_bias)
self.n_heads = n_heads
self.d_head = d_embed // n_heads

def forward(self, x, y):
input_shape = x.shape
batch_size, sequence_length, d_embed = input_shape
interim_shape = (batch_size, -1, self.n_heads, self.d_head)

q = self.q_proj(x)
k = self.k_proj(y)
v = self.v_proj(y)

q = q.view(interim_shape).transpose(1, 2)
k = k.view(interim_shape).transpose(1, 2)
v = v.view(interim_shape).transpose(1, 2)

weight = q @ k.transpose(-1, -2)
weight /= math.sqrt(self.d_head)
weight = F.softmax(weight, dim=-1)

output = weight @ v
output = output.transpose(1, 2).contiguous()
output = output.view(input_shape)
output = self.out_proj(output)
return output
78 changes: 78 additions & 0 deletions ai_edge_torch/generative/examples/stable_diffusion/clip.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# Copyright 2024 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import torch
from torch import nn
from torch._prims_common import mask_tensor
from torch._prims_common.wrappers import out_wrapper

from ai_edge_torch.generative.examples.stable_diffusion.attention import SelfAttention # NOQA


class CLIPEmbedding(nn.Module):

def __init__(self, n_vocab: int, n_embd: int, n_token: int):
super().__init__()
self.token_embedding = nn.Embedding(n_vocab, n_embd)
self.position_value = nn.Parameter(torch.zeros((n_token, n_embd)))

def forward(self, tokens):
x = self.token_embedding(tokens)
x += self.position_value
return x


class CLIPLayer(nn.Module):

def __init__(self, n_head: int, n_embd: int):
super().__init__()
self.layernorm_1 = nn.LayerNorm(n_embd)
self.attention = SelfAttention(n_head, n_embd)
self.layernorm_2 = nn.LayerNorm(n_embd)
self.linear_1 = nn.Linear(n_embd, 4 * n_embd)
self.linear_2 = nn.Linear(4 * n_embd, n_embd)

def forward(self, x):
residue = x
x = self.layernorm_1(x)
x = self.attention(x, causal_mask=True)
x += residue

residue = x
x = self.layernorm_2(x)
x = self.linear_1(x)
x = x * torch.sigmoid(1.702 * x) # QuickGELU activation function
x = self.linear_2(x)
x += residue

return x


class CLIP(nn.Module):

def __init__(self):
super().__init__()
self.embedding = CLIPEmbedding(49408, 768, 77)
self.layers = nn.ModuleList([CLIPLayer(12, 768) for i in range(12)])
self.layernorm = nn.LayerNorm(768)

def forward(self, tokens: torch.LongTensor) -> torch.FloatTensor:
tokens = tokens.type(torch.long)

state = self.embedding(tokens)
for layer in self.layers:
state = layer(state)
output = self.layernorm(state)
return output
111 changes: 111 additions & 0 deletions ai_edge_torch/generative/examples/stable_diffusion/decoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# Copyright 2024 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from torch import nn
from torch.nn import functional as F

from ai_edge_torch.generative.examples.stable_diffusion.attention import SelfAttention # NOQA


class AttentionBlock(nn.Module):

def __init__(self, channels):
super().__init__()
self.groupnorm = nn.GroupNorm(32, channels)
self.attention = SelfAttention(1, channels)

def forward(self, x):
residue = x
x = self.groupnorm(x)

n, c, h, w = x.shape
x = x.view((n, c, h * w))
x = x.transpose(-1, -2)
x = self.attention(x)
x = x.transpose(-1, -2)
x = x.view((n, c, h, w))

x += residue
return x


class ResidualBlock(nn.Module):

def __init__(self, in_channels, out_channels):
super().__init__()
self.groupnorm_1 = nn.GroupNorm(32, in_channels)
self.conv_1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

self.groupnorm_2 = nn.GroupNorm(32, out_channels)
self.conv_2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)

if in_channels == out_channels:
self.residual_layer = nn.Identity()
else:
self.residual_layer = nn.Conv2d(
in_channels, out_channels, kernel_size=1, padding=0
)

def forward(self, x):
residue = x

x = self.groupnorm_1(x)
x = F.silu(x)
x = self.conv_1(x)

x = self.groupnorm_2(x)
x = F.silu(x)
x = self.conv_2(x)

return x + self.residual_layer(residue)


class Decoder(nn.Sequential):

def __init__(self):
super().__init__(
nn.Conv2d(4, 4, kernel_size=1, padding=0),
nn.Conv2d(4, 512, kernel_size=3, padding=1),
ResidualBlock(512, 512),
AttentionBlock(512),
ResidualBlock(512, 512),
ResidualBlock(512, 512),
ResidualBlock(512, 512),
ResidualBlock(512, 512),
nn.Upsample(scale_factor=2),
nn.Conv2d(512, 512, kernel_size=3, padding=1),
ResidualBlock(512, 512),
ResidualBlock(512, 512),
ResidualBlock(512, 512),
nn.Upsample(scale_factor=2),
nn.Conv2d(512, 512, kernel_size=3, padding=1),
ResidualBlock(512, 256),
ResidualBlock(256, 256),
ResidualBlock(256, 256),
nn.Upsample(scale_factor=2),
nn.Conv2d(256, 256, kernel_size=3, padding=1),
ResidualBlock(256, 128),
ResidualBlock(128, 128),
ResidualBlock(128, 128),
nn.GroupNorm(32, 128),
nn.SiLU(),
nn.Conv2d(128, 3, kernel_size=3, padding=1),
)

def forward(self, x):
x = x / 0.18215
for module in self:
x = module(x)
return x
Loading
Loading