Add some assertion statements to improve clarity in the adp code. #23

Open

wants to merge 1 commit into base: main
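The change applies Python's two-argument assert form, assert condition, message, so a failed check raises an AssertionError that explains what went wrong instead of failing with no context. A minimal sketch of the pattern, reusing one of the messages from this diff (the value of dim is illustrative):

    dim = 7  # odd on purpose, to trigger the assertion
    assert (dim % 2) == 0, "Ensure that the value of 'dim' can be divided by 2."
    # AssertionError: Ensure that the value of 'dim' can be divided by 2.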
22 changes: 14 additions & 8 deletions stable_audio_tools/models/adp.py
@@ -77,7 +77,7 @@ def ConvTranspose1d(*args, **kwargs) -> nn.Module:
 def Downsample1d(
     in_channels: int, out_channels: int, factor: int, kernel_multiplier: int = 2
 ) -> nn.Module:
-    assert kernel_multiplier % 2 == 0, "Kernel multiplier must be even"
+    assert kernel_multiplier % 2 == 0, "Kernel multiplier should be divisible by 2."

     return Conv1d(
         in_channels=in_channels,
@@ -217,7 +217,8 @@ def __init__(
         )

         if self.use_mapping:
-            assert exists(context_mapping_features)
+            assert_message = "Ensure that context mapping features exist when use_mapping is set to True."
+            assert exists(context_mapping_features), assert_message
             self.to_scale_shift = MappingToScaleShift(
                 features=context_mapping_features, channels=out_channels
             )
@@ -581,7 +582,7 @@ class LearnedPositionalEmbedding(nn.Module):

     def __init__(self, dim: int):
         super().__init__()
-        assert (dim % 2) == 0
+        assert (dim % 2) == 0, "Ensure that the value of 'dim' can be divided by 2."
         half_dim = dim // 2
         self.weights = nn.Parameter(torch.randn(half_dim))

@@ -657,10 +658,11 @@ def __init__(
         )

         if self.use_transformer:
+            assert_message = "Both attention_heads and attention_multiplier must be specified when using transformer."
             assert (
                 exists(attention_heads)
                 and exists(attention_multiplier)
-            )
+            ), assert_message

             attention_features = default(attention_features, channels // attention_heads)

@@ -765,10 +767,11 @@ def __init__(
         )

         if self.use_transformer:
+            assert_message = "Both attention_heads and attention_multiplier must be specified when using transformer."
             assert (
                 exists(attention_heads)
                 and exists(attention_multiplier)
-            )
+            ), assert_message

             attention_features = default(attention_features, channels // attention_heads)

@@ -857,10 +860,11 @@ def __init__(
         )

         if self.use_transformer:
+            assert_message = "Both attention_heads and attention_multiplier must be specified when using transformer."
             assert (
                 exists(attention_heads)
                 and exists(attention_multiplier)
-            )
+            ), assert_message

             attention_features = default(attention_features, channels // attention_heads)

@@ -972,7 +976,8 @@ def __init__(
         )

         if use_context_time:
-            assert exists(context_mapping_features)
+            assert_message = "When use_context_time is set to True, context_mapping_features must also be specified."
+            assert exists(context_mapping_features), assert_message
             self.to_time = nn.Sequential(
                 TimePositionalEmbedding(
                     dim=channels, out_features=context_mapping_features
@@ -981,7 +986,8 @@ def __init__(
         )

         if use_context_features:
-            assert exists(context_features) and exists(context_mapping_features)
+            assert_message = "When use_context_features is set to True, context_features and context_mapping_features must both be specified."
+            assert exists(context_features) and exists(context_mapping_features), assert_message
             self.to_features = nn.Sequential(
                 nn.Linear(
                     in_features=context_features, out_features=context_mapping_features
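With the messages in place, callers see the intent of a failed check directly. A minimal usage sketch, assuming Downsample1d is importable from stable_audio_tools/models/adp.py as patched here (argument values are illustrative):

    from stable_audio_tools.models.adp import Downsample1d

    try:
        # kernel_multiplier=3 trips the evenness check, now with the clearer message
        Downsample1d(in_channels=64, out_channels=64, factor=2, kernel_multiplier=3)
    except AssertionError as err:
        print(err)  # Kernel multiplier should be divisible by 2.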