Skip to content

Commit

Permalink
allow for local attention to attend to nothing
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Oct 24, 2023
1 parent 8fd178d commit 1b4d80f
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 4 deletions.
13 changes: 11 additions & 2 deletions audiolm_pytorch/soundstream.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,16 @@ def __init__(

for _ in range(depth):
self.layers.append(nn.ModuleList([
LocalMHA(dim = dim, heads = heads, qk_rmsnorm = True, window_size = window_size, use_rotary_pos_emb = not dynamic_pos_bias, use_xpos = True, **kwargs),
LocalMHA(
dim = dim,
heads = heads,
qk_rmsnorm = True,
window_size = window_size,
use_rotary_pos_emb = not dynamic_pos_bias,
gate_values_per_head = True,
use_xpos = True,
**kwargs
),
FeedForward(dim = dim)
]))

Expand Down Expand Up @@ -610,7 +619,7 @@ def __init__(
self.adversarial_loss_weight = adversarial_loss_weight
self.feature_loss_weight = feature_loss_weight

self.register_buffer('zero', torch.tensor([0.]), persistent = False)
self.register_buffer('zero', torch.tensor(0.), persistent = False)

@property
def device(self):
Expand Down
2 changes: 1 addition & 1 deletion audiolm_pytorch/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '1.6.2'
__version__ = '1.6.3'
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
'fairseq',
'joblib',
'lion-pytorch',
'local-attention>=1.8.4',
'local-attention>=1.9.0',
'scikit-learn',
'sentencepiece',
'torch>=1.12',
Expand Down

0 comments on commit 1b4d80f

Please sign in to comment.