
Move attention head dim to config
pglorio committed Sep 24, 2024
1 parent 97c646c commit 1e4ffe6
Showing 2 changed files with 6 additions and 1 deletion.
src/transformers/models/zamba/configuration_zamba.py (5 additions, 0 deletions)
@@ -127,6 +127,7 @@ def __init__(
         intermediate_size=14848,
         num_hidden_layers=76,
         num_attention_heads=16,
+        attention_head_dim=None,
         num_key_value_heads=None,
         n_mamba_heads=2,
         hidden_act="gelu",
@@ -160,6 +161,10 @@ def __init__(
         self.intermediate_size = intermediate_size
         self.num_hidden_layers = num_hidden_layers
         self.num_attention_heads = num_attention_heads
+        if attention_head_dim is None:
+            self.attention_head_dim = 2 * self.hidden_size // self.num_attention_heads
+        else:
+            self.attention_head_dim = attention_head_dim
         self.max_position_embeddings = max_position_embeddings
         self.attention_dropout = attention_dropout

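The fallback above reproduces the expression that was previously hard-coded in the attention module, so the derived default is unchanged while new configs can now override the head dimension explicitly. A minimal sketch of both paths, assuming a transformers build that includes this branch; the values 3712 and 16 are illustrative, not asserted defaults:

from transformers import ZambaConfig

# Default path: attention_head_dim is derived as 2 * hidden_size // num_attention_heads.
config = ZambaConfig(hidden_size=3712, num_attention_heads=16)
assert config.attention_head_dim == 2 * 3712 // 16  # 464

# Override path: an explicit value takes precedence over the derived default.
custom = ZambaConfig(hidden_size=3712, num_attention_heads=16, attention_head_dim=128)
assert custom.attention_head_dim == 128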
src/transformers/models/zamba/modeling_zamba.py (1 addition, 1 deletion)
@@ -283,7 +283,7 @@ def __init__(self, config: ZambaConfig, layer_idx: Optional[int] = None):

         self.hidden_size = config.hidden_size
         self.num_heads = config.num_attention_heads
-        self.head_dim = 2 * self.hidden_size // self.num_heads
+        self.head_dim = config.attention_head_dim
         self.num_key_value_heads = config.num_key_value_heads
         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
         self.max_position_embeddings = config.max_position_embeddings

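On the modeling side, the attention module now reads the head dimension from the config instead of recomputing it, making the config the single source of truth. A reduced, hypothetical stand-in for the touched lines (ZambaAttention's real __init__ does far more):

from types import SimpleNamespace

class AttentionSketch:
    # Hypothetical stand-in; mirrors only the line changed in this commit.
    def __init__(self, config):
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        # Before: self.head_dim = 2 * self.hidden_size // self.num_heads
        # After: use whatever the config resolved (derived or user-supplied).
        self.head_dim = config.attention_head_dim

cfg = SimpleNamespace(hidden_size=3712, num_attention_heads=16, attention_head_dim=464)
assert AttentionSketch(cfg).head_dim == 464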