diff --git a/loralib/layers.py b/loralib/layers.py index 0e54a64b..63786560 100644 --- a/loralib/layers.py +++ b/loralib/layers.py @@ -247,6 +247,8 @@ class ConvLoRA(nn.Module, LoRALayer): def __init__(self, conv_module, in_channels, out_channels, kernel_size, r=0, lora_alpha=1, lora_dropout=0., merge_weights=True, **kwargs): super(ConvLoRA, self).__init__() self.conv = conv_module(in_channels, out_channels, kernel_size, **kwargs) + for name, param in self.conv.named_parameters(): + self.register_parameter(name, param) LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights) assert isinstance(kernel_size, int) # Actual trainable parameters