Problem about reproducing RCAN using your project #58

Open

Senwang98 opened this issue Dec 8, 2020 · 0 comments

Comments

Senwang98 commented Dec 8, 2020

Hi @Paper99,
I am trying to reproduce RCAN based on your code. My code:

import torch
from torch import nn as nn
import math
# from basicsr.models.archs.arch_util import Upsample, make_layer

def make_layer(basic_block, num_basic_block, **kwarg):
    """Make layers by stacking the same blocks.
    Args:
        basic_block (nn.module): nn.module class for basic block.
        num_basic_block (int): number of blocks.
    Returns:
        nn.Sequential: Stacked blocks in nn.Sequential.
    """
    layers = []
    for _ in range(num_basic_block):
        layers.append(basic_block(**kwarg))
    return nn.Sequential(*layers)

class Upsample(nn.Sequential):
    """Upsample module.
    Args:
        scale (int): Scale factor. Supported scales: 2^n and 3.
        num_feat (int): Channel number of intermediate features.
    """

    def __init__(self, scale, num_feat):
        m = []
        if (scale & (scale - 1)) == 0:  # scale = 2^n
            for _ in range(int(math.log(scale, 2))):
                m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
                m.append(nn.PixelShuffle(2))
        elif scale == 3:
            m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
            m.append(nn.PixelShuffle(3))
        else:
            raise ValueError(f'scale {scale} is not supported. '
                             'Supported scales: 2^n and 3.')
        super(Upsample, self).__init__(*m)

class ChannelAttention(nn.Module):
    """Channel attention used in RCAN.
    Args:
        num_feat (int): Channel number of intermediate features.
        squeeze_factor (int): Channel squeeze factor. Default: 16.
    """

    def __init__(self, num_feat, squeeze_factor=16):
        super(ChannelAttention, self).__init__()
        self.attention = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(num_feat, num_feat // squeeze_factor, 1, padding=0),
            nn.ReLU(inplace=True),
            nn.Conv2d(num_feat // squeeze_factor, num_feat, 1, padding=0),
            nn.Sigmoid())

    def forward(self, x):
        y = self.attention(x)
        return x * y


class RCAB(nn.Module):
    """Residual Channel Attention Block (RCAB) used in RCAN.
    Args:
        num_feat (int): Channel number of intermediate features.
        squeeze_factor (int): Channel squeeze factor. Default: 16.
        res_scale (float): Scale the residual. Default: 1.
    """

    def __init__(self, num_feat, squeeze_factor=16, res_scale=1):
        super(RCAB, self).__init__()
        self.res_scale = res_scale

        self.rcab = nn.Sequential(
            nn.Conv2d(num_feat, num_feat, 3, 1, 1), nn.ReLU(True),
            nn.Conv2d(num_feat, num_feat, 3, 1, 1),
            ChannelAttention(num_feat, squeeze_factor))

    def forward(self, x):
        res = self.rcab(x) * self.res_scale
        return res + x


class ResidualGroup(nn.Module):
    """Residual Group of RCAB.
    Args:
        num_feat (int): Channel number of intermediate features.
        num_block (int): Block number in the body network.
        squeeze_factor (int): Channel squeeze factor. Default: 16.
        res_scale (float): Scale the residual. Default: 1.
    """

    def __init__(self, num_feat, num_block, squeeze_factor=16, res_scale=1):
        super(ResidualGroup, self).__init__()

        self.residual_group = make_layer(
            RCAB,
            num_block,
            num_feat=num_feat,
            squeeze_factor=squeeze_factor,
            res_scale=res_scale)
        self.conv = nn.Conv2d(num_feat, num_feat, 3, 1, 1)

    def forward(self, x):
        res = self.conv(self.residual_group(x))
        return res + x


class RCAN(nn.Module):
    """Residual Channel Attention Networks.
    Paper: Image Super-Resolution Using Very Deep Residual Channel Attention
        Networks
    Ref git repo: https://github.com/yulunzhang/RCAN.
    Args:
        num_in_ch (int): Channel number of inputs.
        num_out_ch (int): Channel number of outputs.
        num_feat (int): Channel number of intermediate features.
            Default: 64.
        num_group (int): Number of ResidualGroup. Default: 10.
        num_block (int): Number of RCAB in ResidualGroup. Default: 16.
        squeeze_factor (int): Channel squeeze factor. Default: 16.
        upscale (int): Upsampling factor. Support 2^n and 3.
            Default: 4.
        res_scale (float): Used to scale the residual in residual block.
            Default: 1.
        img_range (float): Image range. Default: 255.
        rgb_mean (tuple[float]): Image mean in RGB orders.
            Default: (0.4488, 0.4371, 0.4040), calculated from DIV2K dataset.
    """

    def __init__(self,
                 num_in_ch=3,
                 num_out_ch=3,
                 num_feat=64,
                 num_group=10,
                 num_block=16,
                 squeeze_factor=16,
                 upscale=2,
                 res_scale=1,
                 img_range=255.,
                 rgb_mean=(0.4488, 0.4371, 0.4040)):
        super(RCAN, self).__init__()

        self.img_range = img_range
        self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)

        self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
        self.body = make_layer(
            ResidualGroup,
            num_group,
            num_feat=num_feat,
            num_block=num_block,
            squeeze_factor=squeeze_factor,
            res_scale=res_scale)
        self.conv_after_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.upsample = Upsample(upscale, num_feat)
        self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)

    def forward(self, x):
        # print(x.shape)
        self.mean = self.mean.type_as(x)

        x = (x - self.mean) * self.img_range
        x = self.conv_first(x)
        res = self.conv_after_body(self.body(x))
        res += x

        x = self.conv_last(self.upsample(res))
        x = x / self.img_range + self.mean
        # print(x.shape)
        # exit()
        return x

This code can run, but the loss is very high (about 1e30). I am quite confused about this; can you give me any suggestions?
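
For reference, this is the kind of quick sanity check I would run on random data (an illustrative sketch only, not my actual training loop; the tensor sizes are made up and the comments state what I would expect, not measured output):

# Sanity-check sketch: forward pass and L1 loss on random tensors.
# RCAN here is the class defined in the listing above.
import torch
import torch.nn.functional as F

model = RCAN(upscale=2)                    # defaults from the code above
lr = torch.rand(2, 3, 48, 48)              # fake LR batch in [0, 1]
hr = torch.rand(2, 3, 96, 96)              # fake HR batch in [0, 1]

with torch.no_grad():
    sr = model(lr)

print(sr.shape)                            # expected: torch.Size([2, 3, 96, 96])
print(sr.min().item(), sr.max().item())    # should be a sane range, nowhere near 1e30
print(F.l1_loss(sr, hr).item())            # L1 on [0, 1] tensors should be well below 1
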
My train.json:

{
    "mode": "sr",
    "use_cl": false,
    // "use_cl": true,
    "gpu_ids": [1],

    "scale": 2,
    "is_train": true,
    "use_chop": true,
    "rgb_range": 255,
    "self_ensemble": false,
    "save_image": false,

    "datasets": {
        "train": {
            "mode": "LRHR",
            "dataroot_HR": "/home/wangsen/ws/dataset/DIV2K/Augment/DIV2K_train_HR_aug/x2",
            "dataroot_LR": "/home/wangsen/ws/dataset/DIV2K/Augment/DIV2K_train_LR_aug/x2",
            "data_type": "npy",
            "n_workers": 8,
            "batch_size": 16,
            "LR_size": 48,
            "use_flip": true,
            "use_rot": true,
            "noise": "."
        },
        "val": {
            "mode": "LRHR",
            "dataroot_HR": "./results/HR/Set5/x2",
            "dataroot_LR": "./results/LR/LRBI/Set5/x2",
            "data_type": "img"
        }
    },

    "networks": {
        "which_model": "RCAN",
        "num_features": 64,
        "in_channels": 3,
        "out_channels": 3,
        "res_scale": 1,
        "num_resgroups":10, 
        "num_resblocks":20,
        "num_reduction":16
    },

    "solver": {
        "type": "ADAM",
        "learning_rate": 0.0002,
        "weight_decay": 0,
        "lr_scheme": "MultiStepLR",
        "lr_steps": [200, 400, 600, 800],
        "lr_gamma": 0.5,
        "loss_type": "l1",
        "manual_seed": 0,
        "num_epochs": 1000,
        "skip_threshold": 3,
        "split_batch": 1,
        "save_ckp_step": 100,
        "save_vis_step": 1,
        "pretrain": null,
        // "pretrain": "resume",
        "pretrained_path": "./experiments/RCAN_in3f64_x4/epochs/last_ckp.pth",
        "cl_weights": [1.0, 1.0, 1.0, 1.0]
    }
}
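
For clarity, this is how I understand the "networks" block of my train.json maps onto the RCAN constructor above. The JSON keys are from my config, but the mapping itself is my assumption about how the project wires them up, not confirmed project code:

# Hypothetical mapping from the "networks" section to the RCAN class above.
net_opt = {
    "num_features": 64,
    "in_channels": 3,
    "out_channels": 3,
    "res_scale": 1,
    "num_resgroups": 10,
    "num_resblocks": 20,
    "num_reduction": 16,
}

model = RCAN(
    num_in_ch=net_opt["in_channels"],
    num_out_ch=net_opt["out_channels"],
    num_feat=net_opt["num_features"],
    num_group=net_opt["num_resgroups"],
    num_block=net_opt["num_resblocks"],
    squeeze_factor=net_opt["num_reduction"],
    upscale=2,                   # matches "scale": 2 in the config
    res_scale=net_opt["res_scale"],
    img_range=255.,              # matches "rgb_range": 255
)
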