Add GeGLU activation and change activation type to activation config for model_config. (#42)

yichunk authored Jun 10, 2024
1 parent fe2afd5 commit 1bc5098
Showing 14 changed files with 67 additions and 32 deletions.
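In practice the API change is mechanical: every call site that previously passed a bare cfg.ActivationType now wraps it in the new cfg.ActivationConfig. A hedged before/after sketch, modeled on the gated feed-forward config in the toy model below (the import alias is an assumption based on the example files in this diff):

import ai_edge_torch.generative.layers.model_config as cfg

# Before this commit: the activation was a bare enum value.
# ff_config = cfg.FeedForwardConfig(
#     type=cfg.FeedForwardType.GATED,
#     activation=cfg.ActivationType.SILU,
#     intermediate_size=256,
# )

# After this commit: the enum is wrapped in an ActivationConfig.
ff_config = cfg.FeedForwardConfig(
    type=cfg.FeedForwardType.GATED,
    activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
    intermediate_size=256,
)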
2 changes: 1 addition & 1 deletion ai_edge_torch/generative/examples/gemma/gemma.py
@@ -116,7 +116,7 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
)
ff_config = cfg.FeedForwardConfig(
type=cfg.FeedForwardType.GATED,
activation=cfg.ActivationType.GELU_TANH,
activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
intermediate_size=16384,
)
norm_config = cfg.NormalizationConfig(
2 changes: 1 addition & 1 deletion ai_edge_torch/generative/examples/phi2/phi2.py
@@ -112,7 +112,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
)
ff_config = cfg.FeedForwardConfig(
type=cfg.FeedForwardType.SEQUENTIAL,
activation=cfg.ActivationType.GELU_TANH,
activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
intermediate_size=10240,
use_bias=True,
)
2 changes: 1 addition & 1 deletion ai_edge_torch/generative/examples/stable_diffusion/clip.py
@@ -90,7 +90,7 @@ def get_model_config() -> cfg.ModelConfig:

ff_config = cfg.FeedForwardConfig(
type=cfg.FeedForwardType.SEQUENTIAL,
activation=cfg.ActivationType.GELU_QUICK,
activation=cfg.ActivationConfig(cfg.ActivationType.GELU_QUICK),
intermediate_size=embedding_dim * 4,
use_bias=True,
)
8 changes: 4 additions & 4 deletions ai_edge_torch/generative/examples/stable_diffusion/decoder.py
@@ -221,7 +221,7 @@ def __init__(self, config: unet_cfg.AutoEncoderConfig):
in_channels=prev_output_channel,
out_channels=block_out_channels,
normalization_config=config.normalization_config,
activation_type=config.activation_type,
activation_config=config.activation_config,
num_layers=config.layers_per_block,
add_upsample=not_final_block,
upsample_conv=True,
@@ -235,7 +235,7 @@ def __init__(self, config: unet_cfg.AutoEncoderConfig):
self.final_norm = layers_builder.build_norm(
block_out_channels, config.normalization_config
)
self.act_fn = layers_builder.get_activation(config.activation_type)
self.act_fn = layers_builder.get_activation(config.activation_config)
self.conv_out = nn.Conv2d(
block_out_channels,
config.out_channels,
@@ -287,7 +287,7 @@ def get_model_config() -> unet_cfg.AutoEncoderConfig:
mid_block_config = unet_cfg.MidBlock2DConfig(
in_channels=block_out_channels[-1],
normalization_config=norm_config,
activation_type=layers_cfg.ActivationType.SILU,
activation_config=layers_cfg.ActivationConfig(layers_cfg.ActivationType.SILU),
num_layers=1,
attention_block_config=att_config,
)
@@ -296,7 +296,7 @@ def get_model_config() -> unet_cfg.AutoEncoderConfig:
in_channels=in_channels,
latent_channels=latent_channels,
out_channels=out_channels,
activation_type=layers_cfg.ActivationType.SILU,
activation_config=layers_cfg.ActivationConfig(layers_cfg.ActivationType.SILU),
block_out_channels=block_out_channels,
scaling_factor=scaling_factor,
layers_per_block=layers_per_block,
ai_edge_torch/generative/examples/stable_diffusion/diffusion.py
@@ -130,7 +130,7 @@ def __init__(self, channels):
self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

def forward(self, x):
x = F.interpolate(x, scale_factor=2, mode='nearest')
x = F.interpolate(x, scale_factor=2, mode="nearest")
return self.conv(x)


@@ -237,3 +237,8 @@ def forward(self, latent, context, time):
output = self.unet(latent, context, time)
output = self.final(output)
return output


if __name__ == "__main__":
diffusion = Diffusion()
print(diffusion.state_dict().keys())
2 changes: 1 addition & 1 deletion ai_edge_torch/generative/examples/t5/t5.py
@@ -349,7 +349,7 @@ def get_model_config_t5() -> cfg.ModelConfig:
)
ff_config = cfg.FeedForwardConfig(
type=cfg.FeedForwardType.SEQUENTIAL,
activation=cfg.ActivationType.RELU,
activation=cfg.ActivationConfig(cfg.ActivationType.RELU),
intermediate_size=3072,
)
# T5 Confirmed as RMS Norm and eps = 1e-6 TJA.
2 changes: 1 addition & 1 deletion ai_edge_torch/generative/examples/test_models/toy_model.py
@@ -76,7 +76,7 @@ def define_and_run() -> None:
)
ff_config = cfg.FeedForwardConfig(
type=cfg.FeedForwardType.GATED,
activation=cfg.ActivationType.SILU,
activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
intermediate_size=256,
)
norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
@@ -95,7 +95,7 @@ def get_model_config() -> cfg.ModelConfig:
)
ff_config = cfg.FeedForwardConfig(
type=cfg.FeedForwardType.GATED,
activation=cfg.ActivationType.SILU,
activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
intermediate_size=256,
)
norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
@@ -83,7 +83,7 @@ def get_model_config() -> cfg.ModelConfig:
)
ff_config = cfg.FeedForwardConfig(
type=cfg.FeedForwardType.GATED,
activation=cfg.ActivationType.SILU,
activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
intermediate_size=256,
)
norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
2 changes: 1 addition & 1 deletion ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py
@@ -112,7 +112,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
)
ff_config = cfg.FeedForwardConfig(
type=cfg.FeedForwardType.GATED,
activation=cfg.ActivationType.SILU,
activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
intermediate_size=5632,
)
norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
40 changes: 31 additions & 9 deletions ai_edge_torch/generative/layers/builder.py
@@ -13,6 +13,7 @@
# limitations under the License.
# ==============================================================================
# Builder class for individual components.
import torch
from torch import nn
import torch.nn.functional as F

@@ -21,6 +22,23 @@
import ai_edge_torch.generative.layers.normalization as normalization


class GeGLU(nn.Module):
"""GeGLU is an activation function which is a variant of GELU.
GeGLU(x) = (xW+b) * GELU(xV+c)
See: https://arxiv.org/abs/2002.05202v1
"""

def __init__(self, d_in: int, d_out: int):
super().__init__()
self.proj = nn.Linear(d_in, d_out * 2)

def forward(self, x: torch.Tensor):
x, gate = self.proj(x).chunk(2, dim=-1)
return x * F.gelu(gate)


def build_norm(dim: int, config: cfg.NormalizationConfig):
"""Builder function for normalizers.
@@ -81,29 +99,33 @@ def build_ff(dim: int, config: cfg.FeedForwardConfig):
)


def get_activation(type_: cfg.ActivationType):
"""Get pytorch callable activation from the name.
def get_activation(config: cfg.ActivationConfig):
"""Get pytorch callable activation from the activation config.
Args:
name (string): activation's name.
config (cfg.ActivationConfig): activation config.
Returns:
Activation function.
Raises:
ValueError: If activation name is not supported.
ValueError: If activation config is not supported.
"""
if type_ == cfg.ActivationType.SILU:
if config.type == cfg.ActivationType.LINEAR:
return lambda x: x
elif config.type == cfg.ActivationType.SILU:
return F.silu
elif type_ == cfg.ActivationType.GELU:
elif config.type == cfg.ActivationType.GELU:
return F.gelu
elif type_ == cfg.ActivationType.GELU_TANH:
elif config.type == cfg.ActivationType.GELU_TANH:
return lambda x: F.gelu(x, approximate="tanh")
elif type_ == cfg.ActivationType.GELU_QUICK:
elif config.type == cfg.ActivationType.GELU_QUICK:
# GELU approximation that is fast but somewhat inaccurate.
# See: https://github.com/hendrycks/GELUs
return lambda x: x * F.sigmoid(1.702 * x)
elif type_ == cfg.ActivationType.RELU:
elif config.type == cfg.ActivationType.GE_GLU:
return GeGLU(config.dim_in, config.dim_out)
elif config.type == cfg.ActivationType.RELU:
return F.relu
else:
raise ValueError("Unsupported activation type.")
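To make the builder changes above concrete, here is a small usage sketch. It is illustrative only: the import aliases and tensor sizes are assumptions, but the calls mirror the diff (get_activation now takes an ActivationConfig, GE_GLU needs dim_in/dim_out because GeGLU owns a Linear projection, and the default LINEAR type maps to a pass-through):

import torch
import ai_edge_torch.generative.layers.builder as layers_builder
import ai_edge_torch.generative.layers.model_config as cfg

x = torch.randn(2, 4)

# Plain activations are now requested through an ActivationConfig.
silu = layers_builder.get_activation(cfg.ActivationConfig(cfg.ActivationType.SILU))
print(silu(x).shape)  # torch.Size([2, 4])

# GE_GLU builds the GeGLU module: a Linear(dim_in, 2 * dim_out) projection split
# into a value half and a GELU-gated half, i.e. (xW + b) * GELU(xV + c).
geglu = layers_builder.get_activation(
    cfg.ActivationConfig(cfg.ActivationType.GE_GLU, dim_in=4, dim_out=8)
)
print(geglu(x).shape)  # torch.Size([2, 8])

# The config's type defaults to LINEAR, which the builder maps to identity.
identity = layers_builder.get_activation(cfg.ActivationConfig())
print(torch.equal(identity(x), x))  # True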
12 changes: 10 additions & 2 deletions ai_edge_torch/generative/layers/model_config.py
@@ -27,7 +27,7 @@ class ActivationType(enum.Enum):
SILU = enum.auto()
GELU = enum.auto()
GELU_TANH = enum.auto()
GELU_QUICK = enum.auto()
GE_GLU = enum.auto()
RELU = enum.auto()


@@ -74,12 +74,20 @@ class AttentionConfig:
relative_attention_max_distance: int = 0


@dataclass
class ActivationConfig:
type: ActivationType = ActivationType.LINEAR
# Dimension of input and output, used in GeGLU.
dim_in: Optional[int] = None
dim_out: Optional[int] = None


@dataclass
class FeedForwardConfig:
"""FeedForward module's parameters."""

type: FeedForwardType
activation: ActivationType
activation: ActivationConfig
intermediate_size: int
use_bias: bool = False

8 changes: 4 additions & 4 deletions ai_edge_torch/generative/layers/unet/blocks_2d.py
@@ -53,7 +53,7 @@ def __init__(self, config: unet_cfg.ResidualBlock2DConfig):
self.conv_2 = nn.Conv2d(
config.out_channels, config.out_channels, kernel_size=3, stride=1, padding=1
)
self.act_fn = layers_builder.get_activation(config.activation_type)
self.act_fn = layers_builder.get_activation(config.activation_config)
if config.in_channels == config.out_channels:
self.residual_layer = nn.Identity()
else:
@@ -167,7 +167,7 @@ def __init__(self, config: unet_cfg.UpDecoderBlock2DConfig):
out_channels=config.out_channels,
time_embedding_channels=config.time_embedding_channels,
normalization_config=config.normalization_config,
activation_type=config.activation_type,
activation_config=config.activation_config,
)
)
)
@@ -244,7 +244,7 @@ def __init__(self, config: unet_cfg.MidBlock2DConfig):
out_channels=config.in_channels,
time_embedding_channels=config.time_embedding_channels,
normalization_config=config.normalization_config,
activation_type=config.activation_type,
activation_config=config.activation_config,
)
)
]
@@ -259,7 +259,7 @@ def __init__(self, config: unet_cfg.MidBlock2DConfig):
out_channels=config.in_channels,
time_embedding_channels=config.time_embedding_channels,
normalization_config=config.normalization_config,
activation_type=config.activation_type,
activation_config=config.activation_config,
)
)
)
8 changes: 4 additions & 4 deletions ai_edge_torch/generative/layers/unet/model_config.py
@@ -39,7 +39,7 @@ class ResidualBlock2DConfig:
in_channels: int
out_channels: int
normalization_config: layers_cfg.NormalizationConfig
activation_type: layers_cfg.ActivationType
activation_config: layers_cfg.ActivationConfig
# Optional time embedding channels if the residual block takes a time embedding context as input
time_embedding_channels: Optional[int] = None

@@ -56,7 +56,7 @@ class UpDecoderBlock2DConfig:
in_channels: int
out_channels: int
normalization_config: layers_cfg.NormalizationConfig
activation_type: layers_cfg.ActivationType
activation_config: layers_cfg.ActivationConfig
num_layers: int
# Optional time embedding channels if the residual blocks take a time embedding context as input
time_embedding_channels: Optional[int] = None
@@ -72,7 +72,7 @@ class UpDecoderBlock2DConfig:
class MidBlock2DConfig:
in_channels: int
normalization_config: layers_cfg.NormalizationConfig
activation_type: layers_cfg.ActivationType
activation_config: layers_cfg.ActivationConfig
num_layers: int
# Optional time embedding channels if the residual blocks take a time embedding context as input
time_embedding_channels: Optional[int] = None
@@ -85,7 +85,7 @@ class AutoEncoderConfig:
"""Configurations of encoder/decoder in the autoencoder model."""

# The activation type of encoder/decoder blocks.
activation_type: layers_cfg.ActivationType
activation_config: layers_cfg.ActivationConfig

# The output channels of each block.
block_out_channels: List[int]
