-
Notifications
You must be signed in to change notification settings - Fork 0
/
psnake_tensorflow.py
100 lines (93 loc) · 3.74 KB
/
psnake_tensorflow.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
from tensorflow.python.framework import dtypes
from tensorflow.python.keras import backend
from tensorflow.python.keras import constraints
from tensorflow.python.keras import initializers
from tensorflow.python.keras import regularizers
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.engine.input_spec import InputSpec
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import math_ops
class Psnake(Layer):
"""Parametric Snake activation (PSnake).
Parametric version of Snake from "Neural Networks Fail to Learn Periodic Functions
and How to Fix It"
It follows:
```
f(x) = x + (1 - cos(2 * self.alpha * x))/(2 * alpha)
```
where `alpha` is a learned array with the same shape as x.
Input shape:
Arbitrary. Use the keyword argument `input_shape`
(tuple of integers, does not include the samples axis)
when using this layer as the first layer in a model.
Output shape:
Same shape as the input.
Args:
alpha_initializer: Initializer function for the weights.
alpha_regularizer: Regularizer for the weights.
alpha_constraint: Constraint for the weights.
shared_axes: The axes along which to share learnable
parameters for the activation function.
For example, if the incoming feature maps
are from a 2D convolution
with output shape `(batch, height, width, channels)`,
and you wish to share parameters across space
so that each filter only has one set of parameters,
set `shared_axes=[1, 2]`.
"""
def __init__(
self,
alpha_initializer="ones",
alpha_regularizer=None,
alpha_constraint=None,
shared_axes=None,
**kwargs
):
super(Psnake, self).__init__(**kwargs)
self.supports_masking = True
self.alpha_initializer = initializers.get(alpha_initializer)
self.alpha_regularizer = regularizers.get(alpha_regularizer)
self.alpha_constraint = constraints.get(alpha_constraint)
if shared_axes is None:
self.shared_axes = None
elif not isinstance(shared_axes, (list, tuple)):
self.shared_axes = [shared_axes]
else:
self.shared_axes = list(shared_axes)
@tf_utils.shape_type_conversion
def build(self, input_shape):
param_shape = list(input_shape[1:])
if self.shared_axes is not None:
for i in self.shared_axes:
param_shape[i - 1] = 1
self.alpha = self.add_weight(
shape=param_shape,
name="alpha",
initializer=self.alpha_initializer,
regularizer=self.alpha_regularizer,
constraint=self.alpha_constraint,
)
# Set input spec
axes = {}
if self.shared_axes:
for i in range(1, len(input_shape)):
if i not in self.shared_axes:
axes[i] = input_shape[i]
self.input_spec = InputSpec(ndim=len(input_shape), axes=axes)
self.built = True
def call(self, inputs):
up = inputs + (1 - tf.math.cos(2 * self.alpha * inputs))
down = 2 * self.alpha
return up / down
def get_config(self):
config = {
"alpha_initializer": initializers.serialize(self.alpha_initializer),
"alpha_regularizer": regularizers.serialize(self.alpha_regularizer),
"alpha_constraint": constraints.serialize(self.alpha_constraint),
"shared_axes": self.shared_axes,
}
base_config = super(Psnake, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
@tf_utils.shape_type_conversion
def compute_output_shape(self, input_shape):
return input_shape