Add GeGLU activation and change activation type to activation config #42

Merged: 1 commit, Jun 10, 2024
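This PR wraps the feed-forward activation in a new cfg.ActivationConfig (instead of passing a bare cfg.ActivationType value) and adds a GeGLU activation to the layer builder. A minimal sketch of the call-site change, reusing the GATED/SILU/256 configuration from the toy-model diffs below:

# Before: the enum value is passed directly.
ff_config = cfg.FeedForwardConfig(
    type=cfg.FeedForwardType.GATED,
    activation=cfg.ActivationType.SILU,
    intermediate_size=256,
)

# After: the enum value is wrapped in an ActivationConfig, which can also
# carry GeGLU projection sizes (dim_in/dim_out) when the type is GE_GLU.
ff_config = cfg.FeedForwardConfig(
    type=cfg.FeedForwardType.GATED,
    activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
    intermediate_size=256,
)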
2 changes: 1 addition & 1 deletion ai_edge_torch/generative/examples/gemma/gemma.py
@@ -116,7 +116,7 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
)
ff_config = cfg.FeedForwardConfig(
type=cfg.FeedForwardType.GATED,
-activation=cfg.ActivationType.GELU_TANH,
+activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
intermediate_size=16384,
)
norm_config = cfg.NormalizationConfig(
2 changes: 1 addition & 1 deletion ai_edge_torch/generative/examples/phi2/phi2.py
@@ -112,7 +112,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
)
ff_config = cfg.FeedForwardConfig(
type=cfg.FeedForwardType.SEQUENTIAL,
-activation=cfg.ActivationType.GELU_TANH,
+activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
intermediate_size=10240,
use_bias=True,
)
@@ -90,7 +90,7 @@ def get_model_config() -> cfg.ModelConfig:

ff_config = cfg.FeedForwardConfig(
type=cfg.FeedForwardType.SEQUENTIAL,
-activation=cfg.ActivationType.GELU_QUICK,
+activation=cfg.ActivationConfig(cfg.ActivationType.GELU_QUICK),
intermediate_size=embedding_dim * 4,
use_bias=True,
)
@@ -221,7 +221,7 @@ def __init__(self, config: unet_cfg.AutoEncoderConfig):
in_channels=prev_output_channel,
out_channels=block_out_channels,
normalization_config=config.normalization_config,
-activation_type=config.activation_type,
+activation_config=config.activation_config,
num_layers=config.layers_per_block,
add_upsample=not_final_block,
upsample_conv=True,
@@ -235,7 +235,7 @@ def __init__(self, config: unet_cfg.AutoEncoderConfig):
self.final_norm = layers_builder.build_norm(
block_out_channels, config.normalization_config
)
-self.act_fn = layers_builder.get_activation(config.activation_type)
+self.act_fn = layers_builder.get_activation(config.activation_config)
self.conv_out = nn.Conv2d(
block_out_channels,
config.out_channels,
@@ -287,7 +287,7 @@ def get_model_config() -> unet_cfg.AutoEncoderConfig:
mid_block_config = unet_cfg.MidBlock2DConfig(
in_channels=block_out_channels[-1],
normalization_config=norm_config,
-activation_type=layers_cfg.ActivationType.SILU,
+activation_config=layers_cfg.ActivationConfig(layers_cfg.ActivationType.SILU),
num_layers=1,
attention_block_config=att_config,
)
@@ -296,7 +296,7 @@ def get_model_config() -> unet_cfg.AutoEncoderConfig:
in_channels=in_channels,
latent_channels=latent_channels,
out_channels=out_channels,
-activation_type=layers_cfg.ActivationType.SILU,
+activation_config=layers_cfg.ActivationConfig(layers_cfg.ActivationType.SILU),
block_out_channels=block_out_channels,
scaling_factor=scaling_factor,
layers_per_block=layers_per_block,
@@ -130,7 +130,7 @@ def __init__(self, channels):
self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

def forward(self, x):
-x = F.interpolate(x, scale_factor=2, mode='nearest')
+x = F.interpolate(x, scale_factor=2, mode="nearest")
return self.conv(x)


@@ -237,3 +237,8 @@ def forward(self, latent, context, time):
output = self.unet(latent, context, time)
output = self.final(output)
return output


+if __name__ == "__main__":
+  diffusion = Diffusion()
+  print(diffusion.state_dict().keys())
2 changes: 1 addition & 1 deletion ai_edge_torch/generative/examples/t5/t5.py
@@ -349,7 +349,7 @@ def get_model_config_t5() -> cfg.ModelConfig:
)
ff_config = cfg.FeedForwardConfig(
type=cfg.FeedForwardType.SEQUENTIAL,
-activation=cfg.ActivationType.RELU,
+activation=cfg.ActivationConfig(cfg.ActivationType.RELU),
intermediate_size=3072,
)
# T5 Confirmed as RMS Norm and eps = 1e-6 TJA.
@@ -76,7 +76,7 @@ def define_and_run() -> None:
)
ff_config = cfg.FeedForwardConfig(
type=cfg.FeedForwardType.GATED,
-activation=cfg.ActivationType.SILU,
+activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
intermediate_size=256,
)
norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
@@ -95,7 +95,7 @@ def get_model_config() -> cfg.ModelConfig:
)
ff_config = cfg.FeedForwardConfig(
type=cfg.FeedForwardType.GATED,
-activation=cfg.ActivationType.SILU,
+activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
intermediate_size=256,
)
norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
@@ -83,7 +83,7 @@ def get_model_config() -> cfg.ModelConfig:
)
ff_config = cfg.FeedForwardConfig(
type=cfg.FeedForwardType.GATED,
-activation=cfg.ActivationType.SILU,
+activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
intermediate_size=256,
)
norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
@@ -112,7 +112,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
)
ff_config = cfg.FeedForwardConfig(
type=cfg.FeedForwardType.GATED,
-activation=cfg.ActivationType.SILU,
+activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
intermediate_size=5632,
)
norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
40 changes: 31 additions & 9 deletions ai_edge_torch/generative/layers/builder.py
@@ -13,6 +13,7 @@
# limitations under the License.
# ==============================================================================
# Builder class for individual components.
+import torch
from torch import nn
import torch.nn.functional as F

@@ -21,6 +22,23 @@
import ai_edge_torch.generative.layers.normalization as normalization


+class GeGLU(nn.Module):
+  """GeGLU is an activation function which is a variant of GELU.
+
+  GeGLU(x) = (xW+b) * GELU(xV+c)
+  See: https://arxiv.org/abs/2002.05202v1
+
+  """
+
+  def __init__(self, d_in: int, d_out: int):
+    super().__init__()
+    self.proj = nn.Linear(d_in, d_out * 2)
+
+  def forward(self, x: torch.Tensor):
+    x, gate = self.proj(x).chunk(2, dim=-1)
+    return x * F.gelu(gate)


def build_norm(dim: int, config: cfg.NormalizationConfig):
"""Builder function for normalizers.

@@ -81,29 +99,33 @@ def build_ff(dim: int, config: cfg.FeedForwardConfig):
)


-def get_activation(type_: cfg.ActivationType):
-  """Get pytorch callable activation from the name.
+def get_activation(config: cfg.ActivationConfig):
+  """Get pytorch callable activation from the activation config.

  Args:
-    name (string): activation's name.
+    config (cfg.ActivationConfig): activation config.

  Returns:
    Activation function.

  Raises:
-    ValueError: If activation name is not supported.
+    ValueError: If activation config is not supported.
  """
-  if type_ == cfg.ActivationType.SILU:
+  if config.type == cfg.ActivationType.LINEAR:
+    return lambda x: x
+  elif config.type == cfg.ActivationType.SILU:
    return F.silu
-  elif type_ == cfg.ActivationType.GELU:
+  elif config.type == cfg.ActivationType.GELU:
    return F.gelu
-  elif type_ == cfg.ActivationType.GELU_TANH:
+  elif config.type == cfg.ActivationType.GELU_TANH:
    return lambda x: F.gelu(x, approximate="tanh")
-  elif type_ == cfg.ActivationType.GELU_QUICK:
+  elif config.type == cfg.ActivationType.GELU_QUICK:
    # GELU approximation that is fast but somewhat inaccurate.
    # See: https://github.com/hendrycks/GELUs
    return lambda x: x * F.sigmoid(1.702 * x)
-  elif type_ == cfg.ActivationType.RELU:
+  elif config.type == cfg.ActivationType.GE_GLU:
+    return GeGLU(config.dim_in, config.dim_out)
+  elif config.type == cfg.ActivationType.RELU:
    return F.relu
  else:
    raise ValueError("Unsupported activation type.")
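The GE_GLU branch is the one case where get_activation returns a module with trainable weights rather than a stateless callable. A small usage sketch, assuming the modules import as shown in this PR's file paths; the tensor sizes here are illustrative, not taken from any model config:

import torch

import ai_edge_torch.generative.layers.builder as builder
import ai_edge_torch.generative.layers.model_config as cfg

# GE_GLU needs dim_in/dim_out because GeGLU owns a Linear(dim_in, 2 * dim_out)
# that packs the paper's W and V projections into a single weight matrix.
act = builder.get_activation(
    cfg.ActivationConfig(cfg.ActivationType.GE_GLU, dim_in=64, dim_out=128)
)
x = torch.randn(2, 10, 64)
y = act(x)  # computes (xW + b) * GELU(xV + c)
print(y.shape)  # torch.Size([2, 10, 128])

# The other branches return plain callables with no parameters, e.g.:
gelu_tanh = builder.get_activation(cfg.ActivationConfig(cfg.ActivationType.GELU_TANH))
print(gelu_tanh(x).shape)  # torch.Size([2, 10, 64])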
12 changes: 10 additions & 2 deletions ai_edge_torch/generative/layers/model_config.py
@@ -27,7 +27,7 @@ class ActivationType(enum.Enum):
SILU = enum.auto()
GELU = enum.auto()
GELU_TANH = enum.auto()
GELU_QUICK = enum.auto()
+GE_GLU = enum.auto()
RELU = enum.auto()


@@ -74,12 +74,20 @@ class AttentionConfig:
relative_attention_max_distance: int = 0


+@dataclass
+class ActivationConfig:
+  type: ActivationType = ActivationType.LINEAR
+  # Dimension of input and output, used in GeGLU.
+  dim_in: Optional[int] = None
+  dim_out: Optional[int] = None


@dataclass
class FeedForwardConfig:
"""FeedForward module's parameters."""

type: FeedForwardType
-activation: ActivationType
+activation: ActivationConfig
intermediate_size: int
use_bias: bool = False

8 changes: 4 additions & 4 deletions ai_edge_torch/generative/layers/unet/blocks_2d.py
@@ -53,7 +53,7 @@ def __init__(self, config: unet_cfg.ResidualBlock2DConfig):
self.conv_2 = nn.Conv2d(
config.out_channels, config.out_channels, kernel_size=3, stride=1, padding=1
)
-self.act_fn = layers_builder.get_activation(config.activation_type)
+self.act_fn = layers_builder.get_activation(config.activation_config)
if config.in_channels == config.out_channels:
self.residual_layer = nn.Identity()
else:
@@ -167,7 +167,7 @@ def __init__(self, config: unet_cfg.UpDecoderBlock2DConfig):
out_channels=config.out_channels,
time_embedding_channels=config.time_embedding_channels,
normalization_config=config.normalization_config,
-activation_type=config.activation_type,
+activation_config=config.activation_config,
)
)
)
@@ -244,7 +244,7 @@ def __init__(self, config: unet_cfg.MidBlock2DConfig):
out_channels=config.in_channels,
time_embedding_channels=config.time_embedding_channels,
normalization_config=config.normalization_config,
-activation_type=config.activation_type,
+activation_config=config.activation_config,
)
)
]
@@ -259,7 +259,7 @@ def __init__(self, config: unet_cfg.MidBlock2DConfig):
out_channels=config.in_channels,
time_embedding_channels=config.time_embedding_channels,
normalization_config=config.normalization_config,
-activation_type=config.activation_type,
+activation_config=config.activation_config,
)
)
)
8 changes: 4 additions & 4 deletions ai_edge_torch/generative/layers/unet/model_config.py
@@ -39,7 +39,7 @@ class ResidualBlock2DConfig:
in_channels: int
out_channels: int
normalization_config: layers_cfg.NormalizationConfig
-activation_type: layers_cfg.ActivationType
+activation_config: layers_cfg.ActivationConfig
# Optional time embedding channels if the residual block takes a time embedding context as input
time_embedding_channels: Optional[int] = None

Expand All @@ -56,7 +56,7 @@ class UpDecoderBlock2DConfig:
in_channels: int
out_channels: int
normalization_config: layers_cfg.NormalizationConfig
-activation_type: layers_cfg.ActivationType
+activation_config: layers_cfg.ActivationConfig
num_layers: int
# Optional time embedding channels if the residual blocks take a time embedding context as input
time_embedding_channels: Optional[int] = None
@@ -72,7 +72,7 @@ class UpDecoderBlock2DConfig:
class MidBlock2DConfig:
in_channels: int
normalization_config: layers_cfg.NormalizationConfig
-activation_type: layers_cfg.ActivationType
+activation_config: layers_cfg.ActivationConfig
num_layers: int
# Optional time embedding channels if the residual blocks take a time embedding context as input
time_embedding_channels: Optional[int] = None
Expand All @@ -85,7 +85,7 @@ class AutoEncoderConfig:
"""Configurations of encoder/decoder in the autoencoder model."""

# The activation type of encoder/decoder blocks.
-activation_type: layers_cfg.ActivationType
+activation_config: layers_cfg.ActivationConfig

# The output channels of each block.
block_out_channels: List[int]