-
Notifications
You must be signed in to change notification settings - Fork 2
/
mod_initialiser.f90
125 lines (107 loc) · 4.47 KB
/
mod_initialiser.f90
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
!!!#############################################################################
!!! Code written by Ned Thaddeus Taylor
!!! Code part of the ATHENA library - a feedforward neural network library
!!!#############################################################################
!!! module contains initialiser functions
!!! module includes the following procedures:
!!! - initialiser_setup - set up initialiser
!!! - get_default_initialiser - get default initialiser based on activation ...
!!! ... function
!!!#############################################################################
!! Examples of initialisers in keras: https://keras.io/api/layers/initializers/
!!!#############################################################################
module initialiser
use misc, only: to_lower
use custom_types, only: initialiser_type
use initialiser_glorot, only: glorot_uniform, glorot_normal
use initialiser_he, only: he_uniform, he_normal
use initialiser_lecun, only: lecun_uniform, lecun_normal
use initialiser_ones, only: ones
use initialiser_zeros, only: zeros
use initialiser_ident, only: ident
use initialiser_gaussian, only: gaussian
implicit none
private
public :: initialiser_setup, get_default_initialiser
contains
!!!#############################################################################
!!! get default initialiser based on activation function (and on whether the
!!! ... initialiser is for a bias)
!!!#############################################################################
!!! activation = (S, in) activation function name (compared case-sensitively)
!!! is_bias    = (B, in, optional) if present and true, initialiser is for bias
!!! name       = (S, out) name of default initialiser; always allocated on return
function get_default_initialiser(activation, is_bias) result(name)
  implicit none
  character(*), intent(in) :: activation
  logical, optional, intent(in) :: is_bias
  character(:), allocatable :: name


  !!--------------------------------------------------------------------------
  !! if bias, use default initialiser of zero
  !!--------------------------------------------------------------------------
  !! only leave early when is_bias is true; when it is present but false the
  !! ... activation-based choice below must still run so name gets assigned
  if(present(is_bias))then
     if(is_bias)then
        name = "zeros"
        return
     end if
  end if


  !!--------------------------------------------------------------------------
  !! set default initialiser based on activation
  !!--------------------------------------------------------------------------
  !! "selu" is tested first because it also contains the substring "elu"
  if(trim(activation).eq."selu")then
     name = "lecun_normal"
  elseif(index(activation,"elu").ne.0)then
     !! any activation name containing "elu" (e.g. "relu", "leaky_relu")
     name = "he_uniform"
  elseif(trim(activation).eq."batch")then
     name = "gaussian"
  else
     name = "glorot_uniform"
  end if

end function get_default_initialiser
!!!#############################################################################
!!!#############################################################################
!!! set up initialiser
!!!#############################################################################
!!! name        = (S, in) name of initialiser (passed through to_lower and
!!!               ... trimmed before matching)
!!! error       = (I, out, optional) error code: 0 on success, -1 if name is
!!!               ... not recognised. If absent, an unrecognised name
!!!               ... terminates the program with error stop
!!! initialiser = (O, out) initialiser object
!!! NOTE(review): when error is present and name is not recognised, the
!!! ... allocatable result is returned unallocated; callers should check error
!!! ... before using the result
function initialiser_setup(name, error) result(initialiser)
  implicit none
  class(initialiser_type), allocatable :: initialiser
  character(*), intent(in) :: name
  integer, optional, intent(out) :: error

  character(:), allocatable :: name_lower


  !!--------------------------------------------------------------------------
  !! initialise error code
  !!--------------------------------------------------------------------------
  !! an intent(out) argument is undefined on entry, so report success up
  !! ... front; the default case below overwrites it on failure
  if(present(error)) error = 0


  !!--------------------------------------------------------------------------
  !! set initialiser function
  !!--------------------------------------------------------------------------
  name_lower = trim(to_lower(name))
  select case(name_lower)
  case("glorot_uniform")
     initialiser = glorot_uniform
  case("glorot_normal")
     initialiser = glorot_normal
  case("he_uniform")
     initialiser = he_uniform
  case("he_normal")
     initialiser = he_normal
  case("lecun_uniform")
     initialiser = lecun_uniform
  case("lecun_normal")
     initialiser = lecun_normal
  case("ones")
     initialiser = ones
  case("zeros")
     initialiser = zeros
  case("ident")
     initialiser = ident
  case("gaussian", "normal")
     !! "normal" is accepted as an alias for "gaussian"
     initialiser = gaussian
  case default
     if(present(error))then
        error = -1
        return
     else
        !! error stop (rather than stop) signals abnormal termination
        error stop "Incorrect initialiser name given '"//name_lower//"'"
     end if
  end select

end function initialiser_setup
!!!#############################################################################
end module initialiser
!!!#############################################################################