Commit 813b700
Internal clean up.
PiperOrigin-RevId: 698155153
hheydary authored and copybara-github committed Nov 19, 2024
1 parent b476f0a commit 813b700
Showing 1 changed file with 23 additions and 20 deletions.
43 changes: 23 additions & 20 deletions ai_edge_torch/generative/layers/model_config.py
--- a/ai_edge_torch/generative/layers/model_config.py
+++ b/ai_edge_torch/generative/layers/model_config.py
@@ -12,9 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-# Model configuration class.
-from dataclasses import dataclass
-from dataclasses import field
+
+"""Model configuration class."""
+
+import dataclasses
 import enum
 from typing import Optional, Sequence, Union
 
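The first hunk carries the whole of the clean-up: the leading comment becomes a module docstring, and the symbol imports of `dataclass` and `field` are replaced with a single `import dataclasses`, so every decorator and field default below is fully qualified (consistent with the Google Python style guide's preference for importing modules rather than individual names). A minimal sketch of the resulting pattern, using an illustrative class rather than one from this file:

import dataclasses


@dataclasses.dataclass
class ExampleConfig:
  """Illustrative config, not part of model_config.py."""

  size: int = 0
  # Mutable defaults go through dataclasses.field, as in the diff below.
  tags: list[str] = dataclasses.field(default_factory=list)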
@@ -35,7 +36,7 @@ class ActivationType(enum.Enum):
 
 @enum.unique
 class NormalizationType(enum.Enum):
-  """Different normalization functions"""
+  """Different normalization functions."""
 
   # No normalization is applied.
   NONE = enum.auto()
@@ -59,7 +60,7 @@ class AttentionType(enum.Enum):
   LOCAL_SLIDING = enum.auto()
 
 
-@dataclass
+@dataclasses.dataclass
 class NormalizationConfig:
   """Normalizater parameters."""
 
@@ -71,7 +72,7 @@ class NormalizationConfig:
   group_num: Optional[float] = None
 
 
-@dataclass
+@dataclasses.dataclass
 class AttentionConfig:
   """Attention model's parameters."""
 
@@ -90,18 +91,20 @@ class AttentionConfig:
   # Whether to use bias with Query, Key, and Value projection.
   qkv_use_bias: bool = False
   # Whether the fused q, k, v projection weights interleaves q, k, v heads.
-  # If True, the projection weights are in format [q_head_0, k_head_0, v_head_0, q_head_1, k_head_1, v_head_1, ...]
-  # If False, the projection weights are in format [q_head_0, q_head_1, ..., k_head_0, k_head_1, ... v_head_0, v_head_1, ...]
+  # If True, the projection weights are in format:
+  # `[q_head_0, k_head_0, v_head_0, q_head_1, k_head_1, v_head_1, ...]`
+  # If False, the projection weights are in format:
+  # `[q_head_0, q_head_1, ..., k_head_0, k_head_1, ... v_head_0, v_head_1, ...]`
   qkv_fused_interleaved: bool = True
   # Whether to use bias with attention output projection.
   output_proj_use_bias: bool = False
   enable_kv_cache: bool = True
   # The normalization applied to query projection's output.
-  query_norm_config: NormalizationConfig = field(
+  query_norm_config: NormalizationConfig = dataclasses.field(
       default_factory=NormalizationConfig
   )
   # The normalization applied to key projection's output.
-  key_norm_config: NormalizationConfig = field(
+  key_norm_config: NormalizationConfig = dataclasses.field(
       default_factory=NormalizationConfig
   )
   relative_attention_num_buckets: int = 0
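The rewrapped comment in this hunk describes the two possible layouts of the fused q/k/v projection weights selected by `qkv_fused_interleaved`. A small illustration of the two orderings (plain Python; the variable names here are assumptions for readability, not part of the diff):

# Illustrative only: enumerate per-head blocks of a fused QKV projection.
num_heads = 2  # assumed small value for readability

# qkv_fused_interleaved=True: q, k, v blocks are grouped per head index.
interleaved = [
    f"{proj}_head_{i}" for i in range(num_heads) for proj in ("q", "k", "v")
]
# -> ['q_head_0', 'k_head_0', 'v_head_0', 'q_head_1', 'k_head_1', 'v_head_1']

# qkv_fused_interleaved=False: all q heads, then all k heads, then all v heads.
grouped = [
    f"{proj}_head_{i}" for proj in ("q", "k", "v") for i in range(num_heads)
]
# -> ['q_head_0', 'q_head_1', 'k_head_0', 'k_head_1', 'v_head_0', 'v_head_1']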
@@ -114,15 +117,15 @@ class AttentionConfig:
   sliding_window_size: Optional[int] = None
 
 
-@dataclass
+@dataclasses.dataclass
 class ActivationConfig:
   type: ActivationType = ActivationType.LINEAR
   # Dimension of input and output, used in GeGLU.
   dim_in: Optional[int] = None
   dim_out: Optional[int] = None
 
 
-@dataclass
+@dataclasses.dataclass
 class FeedForwardConfig:
   """FeedForward module's parameters."""
 
@@ -131,27 +134,27 @@ class FeedForwardConfig:
   intermediate_size: int
   use_bias: bool = False
   # The normalization applied to feed forward's input.
-  pre_ff_norm_config: NormalizationConfig = field(
+  pre_ff_norm_config: NormalizationConfig = dataclasses.field(
       default_factory=NormalizationConfig
   )
   # The normalization applied to feed forward's output.
-  post_ff_norm_config: NormalizationConfig = field(
+  post_ff_norm_config: NormalizationConfig = dataclasses.field(
       default_factory=NormalizationConfig
   )
 
 
-@dataclass
+@dataclasses.dataclass
 class TransformerBlockConfig:
   """TransformerBlock module's parameters."""
 
   attn_config: AttentionConfig
   ff_config: FeedForwardConfig
   # The normalization applied to attention's input.
-  pre_attention_norm_config: NormalizationConfig = field(
+  pre_attention_norm_config: NormalizationConfig = dataclasses.field(
       default_factory=NormalizationConfig
   )
   # The normalization applied to attentions's output.
-  post_attention_norm_config: NormalizationConfig = field(
+  post_attention_norm_config: NormalizationConfig = dataclasses.field(
       default_factory=NormalizationConfig
   )
   # If set to True, only attn_config.pre_attention_norm is applied to the input
@@ -163,7 +166,7 @@ class TransformerBlockConfig:
   relative_attention: bool = False
 
 
-@dataclass
+@dataclasses.dataclass
 class ImageEmbeddingConfig:
   """Image embedding parameters."""
 
@@ -173,7 +176,7 @@ class ImageEmbeddingConfig:
   patch_size: int
 
 
-@dataclass
+@dataclasses.dataclass
 class ModelConfig:
   """Base configurations for building a transformer architecture."""
 
@@ -187,7 +190,7 @@ class ModelConfig:
   block_configs: Union[TransformerBlockConfig, Sequence[TransformerBlockConfig]]
 
   # The normalization applied before LM head.
-  final_norm_config: NormalizationConfig = field(
+  final_norm_config: NormalizationConfig = dataclasses.field(
       default_factory=NormalizationConfig
   )
 
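Every `field(...)` call in the file becomes `dataclasses.field(...)`; the `default_factory` idiom itself is unchanged. It is what gives each config instance its own fresh `NormalizationConfig` default, since a mutable class instance cannot be used directly as a dataclass default. A self-contained sketch of that behavior, with stand-in classes rather than the real ones from model_config.py:

import dataclasses


@dataclasses.dataclass
class Norm:
  """Stand-in for NormalizationConfig (illustrative only)."""

  enable: bool = False


@dataclasses.dataclass
class Block:
  """Stand-in config using the same default_factory idiom as the diff."""

  # Each Block() call constructs a new Norm instance.
  norm: Norm = dataclasses.field(default_factory=Norm)


a, b = Block(), Block()
a.norm.enable = True
assert b.norm.enable is False  # defaults are not shared between instances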