Commit 091c018: "Update"
pglorio committed Jul 19, 2024 (parent: bf84312)

Showing 2 changed files with 13 additions and 3 deletions.
setup.py (10 additions, 0 deletions)

@@ -189,6 +189,11 @@
"urllib3<2.0.0",
"uvicorn",
"pytest-rich",
"ninja",
"einops",
"triton",
"causal-conv1d==1.2.0.post2",
"mamba-ssm==2.2.2",
]


@@ -426,6 +431,11 @@ def run(self):
deps["tokenizers"],
deps["safetensors"],
deps["tqdm"], # progress bars in model download and training scripts
deps["mamba-ssm"],
deps["ninja"],
deps["einops"],
deps["triton"],
deps["causal-conv1d"],
]

 setup(
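For context on the second hunk: setup.py derives a name-to-specifier table from the pinned dependency list, so a bare lookup such as deps["mamba-ssm"] resolves to the full pin "mamba-ssm==2.2.2" declared in the first hunk. A minimal sketch of that pattern (the parsing regex and the _deps/deps names are illustrative, not necessarily the repository's exact code):

import re

# Pinned specifiers, as extended by this commit.
_deps = [
    "ninja",
    "einops",
    "triton",
    "causal-conv1d==1.2.0.post2",
    "mamba-ssm==2.2.2",
]

# Map each package name to its full specifier,
# e.g. "causal-conv1d" -> "causal-conv1d==1.2.0.post2".
deps = {re.split(r"[!=<>~ ]", spec)[0]: spec for spec in _deps}

install_requires = [
    deps["mamba-ssm"],      # "mamba-ssm==2.2.2"
    deps["causal-conv1d"],  # "causal-conv1d==1.2.0.post2"
]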
src/transformers/models/zamba2/mlp.py (3 additions, 3 deletions)

@@ -7,16 +7,16 @@
 import torch.nn.functional as F
 import torch.nn as nn
 from .utils import bias_gelu_impl
-from .mamba_config import MambaConfig
+from .configuration_zamba2 import Zamba2Config

 class MLP(nn.Module):
 
-    def __init__(self, config: MambaConfig, is_expert: bool = False, layer_idx=None, num_mem_blocks=None):
+    def __init__(self, config: Zamba2Config, is_expert: bool = False, layer_idx=None, num_mem_blocks=None):
         super().__init__()
 
         self.num_mem_blocks = num_mem_blocks
 
-        self.config: MambaConfig = config
+        self.config: Zamba2Config = config
         self.layer = layer_idx
         ffn_hidden_size_1 = self.config.ffn_hidden_size
         ffn_hidden_size_2 = self.config.ffn_hidden_size
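The mlp.py hunk is a pure rename of the config type, from MambaConfig to Zamba2Config; the fields the constructor reads, such as ffn_hidden_size, are unchanged. A hypothetical instantiation under the new annotation (the keyword arguments and values below are placeholders, and Zamba2Config may require more configuration than is shown here):

from transformers.models.zamba2.configuration_zamba2 import Zamba2Config
from transformers.models.zamba2.mlp import MLP

# Placeholder values; the real config defines many more fields.
config = Zamba2Config(ffn_hidden_size=4096)
mlp = MLP(config, is_expert=False, layer_idx=0, num_mem_blocks=2)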
