-
Notifications
You must be signed in to change notification settings - Fork 85
/
nf_metrics.f90
72 lines (56 loc) · 1.79 KB
/
nf_metrics.f90
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
module nf_metrics

  !! This module provides a collection of metric functions that compare a
  !! vector of true values (labels) with a vector of predicted values.
  !!
  !! Each metric is a concrete extension of the abstract `metric_type`.
  !! Its `eval` binding is `nopass`, so it is invoked as
  !! `metric%eval(true, predicted)` without reading any instance data.

  implicit none

  private
  public :: metric_type
  public :: corr
  public :: maxabs

  type, abstract :: metric_type
    !! Abstract base type for all metrics.
  contains
    procedure(metric_interface), nopass, deferred :: eval
  end type metric_type

  abstract interface

    pure function metric_interface(true, predicted) result(res)
      !! Signature shared by every metric's `eval` binding.
      real, intent(in) :: true(:)
        !! True values, i.e. labels from training datasets
      real, intent(in) :: predicted(:)
        !! Values predicted by the network
      real :: res
        !! Scalar metric value
    end function metric_interface

  end interface

  type, extends(metric_type) :: corr
    !! Pearson correlation
  contains
    procedure, nopass :: eval => corr_eval
  end type corr

  type, extends(metric_type) :: maxabs
    !! Maximum absolute difference
  contains
    procedure, nopass :: eval => maxabs_eval
  end type maxabs

contains

  pure function corr_eval(true, predicted) result(res)
    !! Pearson correlation function:
    !!
    !!   r = sum(dt * dp) / (sqrt(sum(dt**2)) * sqrt(sum(dp**2)))
    !!
    !! where dt = true - mean(true) and dp = predicted - mean(predicted).
    !!
    !! `true` and `predicted` must have the same size. The result is
    !! undefined for empty inputs or when either input has zero variance,
    !! because the denominator is then zero. Under IEEE arithmetic this
    !! typically yields NaN, and it may trap if FP exceptions are enabled.
    real, intent(in) :: true(:)
      !! True values, i.e. labels from training datasets
    real, intent(in) :: predicted(:)
      !! Values predicted by the network
    real :: res
      !! Resulting correlation value
    real, allocatable :: dev_true(:), dev_pred(:)
    real :: norm_true, norm_pred

    ! Deviations from the mean, computed once and reused below.
    dev_true = true - sum(true) / real(size(true))
    dev_pred = predicted - sum(predicted) / real(size(predicted))

    ! Take each square root separately. Multiplying the two sums of squares
    ! first can overflow default real (e.g. ~5e20 * ~5e20) even when the
    ! correlation itself is well defined, which would silently return 0.
    norm_true = sqrt(sum(dev_true**2))
    norm_pred = sqrt(sum(dev_pred**2))

    res = dot_product(dev_true, dev_pred) / (norm_true * norm_pred)

  end function corr_eval

  pure function maxabs_eval(true, predicted) result(res)
    !! Maximum absolute difference function:
    !!
    !!   res = max(|true - predicted|)
    !!
    !! `true` and `predicted` must have the same size, because the
    !! element-wise difference is non-conforming otherwise.
    real, intent(in) :: true(:)
      !! True values, i.e. labels from training datasets
    real, intent(in) :: predicted(:)
      !! Values predicted by the network
    real :: res
      !! Resulting maximum absolute difference value

    res = maxval(abs(true - predicted))

  end function maxabs_eval

end module nf_metrics