From 1b4d80f93a2cc0c9dd4797959afa613aac9d029b Mon Sep 17 00:00:00 2001 From: lucidrains Date: Tue, 24 Oct 2023 10:32:23 -0700 Subject: [PATCH] allow for local attention to attend to nothing --- audiolm_pytorch/soundstream.py | 13 +++++++++++-- audiolm_pytorch/version.py | 2 +- setup.py | 2 +- 3 files changed, 13 insertions(+), 4 deletions(-) diff --git a/audiolm_pytorch/soundstream.py b/audiolm_pytorch/soundstream.py index e6c1f15..50fdef9 100644 --- a/audiolm_pytorch/soundstream.py +++ b/audiolm_pytorch/soundstream.py @@ -396,7 +396,16 @@ def __init__( for _ in range(depth): self.layers.append(nn.ModuleList([ - LocalMHA(dim = dim, heads = heads, qk_rmsnorm = True, window_size = window_size, use_rotary_pos_emb = not dynamic_pos_bias, use_xpos = True, **kwargs), + LocalMHA( + dim = dim, + heads = heads, + qk_rmsnorm = True, + window_size = window_size, + use_rotary_pos_emb = not dynamic_pos_bias, + gate_values_per_head = True, + use_xpos = True, + **kwargs + ), FeedForward(dim = dim) ])) @@ -610,7 +619,7 @@ def __init__( self.adversarial_loss_weight = adversarial_loss_weight self.feature_loss_weight = feature_loss_weight - self.register_buffer('zero', torch.tensor([0.]), persistent = False) + self.register_buffer('zero', torch.tensor(0.), persistent = False) @property def device(self): diff --git a/audiolm_pytorch/version.py b/audiolm_pytorch/version.py index 4a9b978..4574cc8 100644 --- a/audiolm_pytorch/version.py +++ b/audiolm_pytorch/version.py @@ -1 +1 @@ -__version__ = '1.6.2' +__version__ = '1.6.3' diff --git a/setup.py b/setup.py index 18f8cf9..b08c6d4 100644 --- a/setup.py +++ b/setup.py @@ -27,7 +27,7 @@ 'fairseq', 'joblib', 'lion-pytorch', - 'local-attention>=1.8.4', + 'local-attention>=1.9.0', 'scikit-learn', 'sentencepiece', 'torch>=1.12',