-
Notifications
You must be signed in to change notification settings - Fork 46
/
model.py
74 lines (55 loc) · 2.56 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
from tensorflow.python.keras.layers import Add, BatchNormalization, Conv2D, Dense, Flatten, Input, LeakyReLU, PReLU, Lambda
from tensorflow.python.keras.models import Model
from tensorflow.python.keras.applications.vgg19 import VGG19
from utils import subpixel_conv2d,normalize_01,normalize_m11,normalize,denormalize,denormalize_m11
def upsample(x_in, num_filters):
    """Upsample a feature map by a factor of 2 with a sub-pixel convolution stage.

    A 3x3 convolution produces ``num_filters`` channels. ``utils.subpixel_conv2d(scale=2)``
    then rearranges them spatially, and a PReLU (one slope per channel, shared over
    both spatial axes) is applied.

    Args:
        x_in: Input Keras tensor.
        num_filters: Number of channels produced by the convolution before the
            rearrangement. This presumably must be a multiple of 4 (scale**2) for a
            depth-to-space shuffle; verify against ``utils.subpixel_conv2d``.

    Returns:
        The activated, upsampled Keras tensor.
    """
    features = Conv2D(num_filters, kernel_size=3, padding='same')(x_in)
    # NOTE(review): assumes subpixel_conv2d returns a callable suitable for Lambda
    # (e.g. a depth_to_space wrapper) -- confirm in utils.
    shuffled = Lambda(subpixel_conv2d(scale=2))(features)
    activation = PReLU(shared_axes=[1, 2])
    return activation(shuffled)
def res_block(x_in, num_filters, momentum=0.8):
    """Apply one residual block: (conv -> BN -> PReLU -> conv -> BN) + identity.

    Args:
        x_in: Input Keras tensor. Its channel count must equal ``num_filters`` so
            the identity shortcut can be added.
        num_filters: Channels of both 3x3 convolutions.
        momentum: Momentum passed to both BatchNormalization layers.

    Returns:
        Keras tensor ``x_in + F(x_in)``.
    """
    h = x_in
    # Two conv/BN stages; only the first one is followed by a PReLU. The layer
    # construction order (Conv, BN, PReLU, Conv, BN, Add) matches the
    # straight-line version, so auto-generated layer names are unchanged.
    for is_first_stage in (True, False):
        h = Conv2D(num_filters, kernel_size=3, padding='same')(h)
        h = BatchNormalization(momentum=momentum)(h)
        if is_first_stage:
            h = PReLU(shared_axes=[1, 2])(h)
    return Add()([x_in, h])
def generator(num_filters=64, num_res_blocks=16):
    """Build the super-resolution generator network (SRResNet-style).

    Pipeline:
    - input normalization;
    - 9x9 conv + PReLU;
    - ``num_res_blocks`` residual blocks;
    - conv + BN joined to the pre-trunk features by a long skip connection;
    - two ``upsample`` stages, each built with ``scale=2``;
    - 9x9 tanh conv to 3 channels;
    - output denormalization.

    Args:
        num_filters: Channel width of the convolutional trunk.
        num_res_blocks: Number of ``res_block`` units in the trunk.

    Returns:
        A Keras ``Model`` taking a 3-channel image of any spatial size.
    """
    lr_image = Input(shape=(None, None, 3))
    # NOTE(review): normalize_01 / denormalize_m11 live in utils; presumably they
    # map pixels to [0, 1] and map tanh output in [-1, 1] back to pixel range.
    h = Lambda(normalize_01)(lr_image)

    h = Conv2D(num_filters, kernel_size=9, padding='same')(h)
    h = PReLU(shared_axes=[1, 2])(h)
    skip = h  # long skip connection spanning the whole residual trunk

    for _ in range(num_res_blocks):
        h = res_block(h, num_filters)

    h = Conv2D(num_filters, kernel_size=3, padding='same')(h)
    h = BatchNormalization()(h)
    h = Add()([skip, h])

    # Two successive upsampling stages, each with 4x the trunk width before
    # the sub-pixel rearrangement.
    for _ in range(2):
        h = upsample(h, num_filters * 4)

    h = Conv2D(3, kernel_size=9, padding='same', activation='tanh')(h)
    sr_image = Lambda(denormalize_m11)(h)
    return Model(lr_image, sr_image)
def discriminator_block(x_in, num_filters, strides=1, batchnorm=True, momentum=0.8):
    """Apply conv(3x3, strides) -> optional BatchNorm -> LeakyReLU(0.2).

    Args:
        x_in: Input Keras tensor.
        num_filters: Output channels of the convolution.
        strides: Convolution stride; 2 halves each spatial dimension ('same' padding).
        batchnorm: Whether to insert a BatchNormalization layer after the conv.
        momentum: Momentum for the BatchNormalization layer, when used.

    Returns:
        The activated Keras tensor.
    """
    out = Conv2D(num_filters, kernel_size=3, strides=strides, padding='same')(x_in)
    out = BatchNormalization(momentum=momentum)(out) if batchnorm else out
    return LeakyReLU(alpha=0.2)(out)
def discriminator(num_filters=64, HR_SIZE=96):
    """Build the GAN discriminator that scores a high-resolution image.

    The network has eight ``discriminator_block`` stages. The widths are
    ``num_filters`` x (1, 1, 2, 2, 4, 4, 8, 8), and within each width the stride
    alternates 1 then 2. Batch-norm is applied everywhere except the very first
    block. The stack ends with Flatten -> Dense(1024) -> LeakyReLU(0.2) ->
    Dense(1, sigmoid).

    Args:
        num_filters: Base channel width.
        HR_SIZE: Height and width of the input image. The input shape must be
            fixed because Flatten feeds a Dense layer.

    Returns:
        A Keras ``Model`` mapping an ``(HR_SIZE, HR_SIZE, 3)`` image to a single
        sigmoid score.
    """
    hr_image = Input(shape=(HR_SIZE, HR_SIZE, 3))
    # NOTE(review): normalize_m11 lives in utils; presumably maps pixels to [-1, 1].
    h = Lambda(normalize_m11)(hr_image)

    # Blocks are built in exactly the same order as an unrolled version.
    for multiplier in (1, 2, 4, 8):
        for stride in (1, 2):
            use_bn = not (multiplier == 1 and stride == 1)
            h = discriminator_block(h, num_filters * multiplier, strides=stride, batchnorm=use_bn)

    h = Flatten()(h)
    h = Dense(1024)(h)
    h = LeakyReLU(alpha=0.2)(h)
    score = Dense(1, activation='sigmoid')(h)
    return Model(hr_image, score)
def vgg(layer_name='block5_conv4'):
    """Build a VGG19 feature extractor for a perceptual (content) loss.

    The default ``'block5_conv4'`` is the layer the original code selected via
    the magic index ``layers[20]``. That is the last conv of block 5 in Keras'
    VGG19 with ``include_top=False``: index 0 is the InputLayer and blocks 1-5
    contribute 3, 3, 5, 5 and 5 layers. Selecting by name is equivalent but does
    not silently break if the layer list changes.

    Args:
        layer_name: Name of the VGG19 layer whose output is returned.

    Returns:
        A Keras ``Model`` from a 3-channel image of any spatial size to the
        chosen layer's feature maps. No VGG input preprocessing is applied here;
        callers are responsible for it.
    """
    # VGG19 defaults to weights='imagenet' (downloaded on first use).
    base = VGG19(input_shape=(None, None, 3), include_top=False)
    return Model(base.input, base.get_layer(layer_name).output)