From fc70d518d3770788d17a5d9799e08d23ad19c525 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Wed, 17 May 2023 14:08:12 -0700 Subject: [PATCH] make conformer able to do things autoregressively, to save issues with variable lengths in soundstorm --- conformer/conformer.py | 12 ++++++++---- setup.py | 2 +- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/conformer/conformer.py b/conformer/conformer.py index 6762d13..44726b8 100644 --- a/conformer/conformer.py +++ b/conformer/conformer.py @@ -149,7 +149,8 @@ def __init__( causal = False, expansion_factor = 2, kernel_size = 31, - dropout = 0.): + dropout = 0. + ): super().__init__() inner_dim = dim * expansion_factor @@ -185,12 +186,13 @@ def __init__( conv_kernel_size = 31, attn_dropout = 0., ff_dropout = 0., - conv_dropout = 0. + conv_dropout = 0., + conv_causal = False ): super().__init__() self.ff1 = FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout) self.attn = Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout) - self.conv = ConformerConvModule(dim = dim, causal = False, expansion_factor = conv_expansion_factor, kernel_size = conv_kernel_size, dropout = conv_dropout) + self.conv = ConformerConvModule(dim = dim, causal = conv_causal, expansion_factor = conv_expansion_factor, kernel_size = conv_kernel_size, dropout = conv_dropout) self.ff2 = FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout) self.attn = PreNorm(dim, self.attn) @@ -222,7 +224,8 @@ def __init__( conv_kernel_size = 31, attn_dropout = 0., ff_dropout = 0., - conv_dropout = 0. + conv_dropout = 0., + conv_causal = False ): super().__init__() self.dim = dim @@ -236,6 +239,7 @@ def __init__( ff_mult = ff_mult, conv_expansion_factor = conv_expansion_factor, conv_kernel_size = conv_kernel_size, + conv_causal = conv_causal )) diff --git a/setup.py b/setup.py index 3bd36b0..e6827ba 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'conformer', packages = find_packages(), - version = '0.3.1', + version = '0.3.2', license='MIT', description = 'The convolutional module from the Conformer paper', author = 'Phil Wang',