diff --git a/stable_audio_tools/models/adp.py b/stable_audio_tools/models/adp.py
index 3041446..7422779 100644
--- a/stable_audio_tools/models/adp.py
+++ b/stable_audio_tools/models/adp.py
@@ -77,7 +77,7 @@ def ConvTranspose1d(*args, **kwargs) -> nn.Module:
 def Downsample1d(
     in_channels: int, out_channels: int, factor: int, kernel_multiplier: int = 2
 ) -> nn.Module:
-    assert kernel_multiplier % 2 == 0, "Kernel multiplier must be even"
+    assert kernel_multiplier % 2 == 0, "Kernel multiplier must be divisible by 2."
 
     return Conv1d(
         in_channels=in_channels,
@@ -217,7 +217,8 @@ def __init__(
         )
 
         if self.use_mapping:
-            assert exists(context_mapping_features)
+            assert_message = "Ensure that context_mapping_features is specified when use_mapping is set to True."
+            assert exists(context_mapping_features), assert_message
             self.to_scale_shift = MappingToScaleShift(
                 features=context_mapping_features, channels=out_channels
             )
@@ -581,7 +582,7 @@ class LearnedPositionalEmbedding(nn.Module):
 
     def __init__(self, dim: int):
         super().__init__()
-        assert (dim % 2) == 0
+        assert (dim % 2) == 0, "Ensure that 'dim' is divisible by 2."
         half_dim = dim // 2
         self.weights = nn.Parameter(torch.randn(half_dim))
 
@@ -657,10 +658,11 @@ def __init__(
         )
 
         if self.use_transformer:
+            assert_message = "Both attention_heads and attention_multiplier must be specified when using a transformer."
             assert (
                 exists(attention_heads)
                 and exists(attention_multiplier)
-            )
+            ), assert_message
 
             attention_features = default(attention_features, channels // attention_heads)
 
@@ -765,10 +767,11 @@ def __init__(
         )
 
         if self.use_transformer:
+            assert_message = "Both attention_heads and attention_multiplier must be specified when using a transformer."
             assert (
                 exists(attention_heads)
                 and exists(attention_multiplier)
-            )
+            ), assert_message
 
             attention_features = default(attention_features, channels // attention_heads)
 
@@ -857,10 +860,11 @@ def __init__(
         )
 
         if self.use_transformer:
+            assert_message = "Both attention_heads and attention_multiplier must be specified when using a transformer."
             assert (
                 exists(attention_heads)
                 and exists(attention_multiplier)
-            )
+            ), assert_message
 
             attention_features = default(attention_features, channels // attention_heads)
 
@@ -972,7 +976,8 @@ def __init__(
         )
 
         if use_context_time:
-            assert exists(context_mapping_features)
+            assert_message = "When use_context_time is set to True, context_mapping_features must also be specified."
+            assert exists(context_mapping_features), assert_message
             self.to_time = nn.Sequential(
                 TimePositionalEmbedding(
                     dim=channels, out_features=context_mapping_features
@@ -981,7 +986,8 @@ def __init__(
         )
 
         if use_context_features:
-            assert exists(context_features) and exists(context_mapping_features)
+            assert_message = "When use_context_features is set to True, context_features and context_mapping_features must both be specified."
+            assert exists(context_features) and exists(context_mapping_features), assert_message
             self.to_features = nn.Sequential(
                 nn.Linear(
                     in_features=context_features, out_features=context_mapping_features
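For reviewers, a minimal sketch of the failure mode this patch improves, assuming `stable_audio_tools` (and its `torch` dependency) is importable; the constructor argument values below are illustrative, not taken from the patch:

```python
# Hypothetical repro: triggering the Downsample1d check with an odd
# kernel_multiplier now surfaces a descriptive message instead of a
# bare AssertionError.
from stable_audio_tools.models.adp import Downsample1d

try:
    # kernel_multiplier=3 violates the evenness requirement asserted above.
    Downsample1d(in_channels=32, out_channels=32, factor=2, kernel_multiplier=3)
except AssertionError as err:
    print(err)  # -> Kernel multiplier must be divisible by 2.
```

One caveat worth keeping in mind: these are plain `assert` statements, so both the checks and their messages are stripped when Python runs with `-O`; they are a debugging aid rather than runtime validation.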