-
Notifications
You must be signed in to change notification settings - Fork 153
/
lswolfe.lua
192 lines (173 loc) · 5.8 KB
/
lswolfe.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
--[[ A Line Search satisfying the Wolfe conditions

The search runs in two phases: a bracketing phase that grows the step until
it either finds an acceptable point or an interval [t_prev, t] that must
contain one, and a zoom phase that shrinks that interval by polynomial
interpolation (`optim.polyinterp`) until an acceptable point is found, the
iteration budget is exhausted, or the interval becomes negligibly small.

ARGS:
- `opfunc` : a function (the objective) that takes a single input (X),
             the point of evaluation, and returns f(X) and df/dX
- `x`      : initial point / starting location. It is modified in place:
             on return it holds x+t*d for the returned `t`
- `t`      : initial step size
- `d`      : descent direction
- `f`      : initial function value
- `g`      : gradient at initial location
- `gtd`    : directional derivative at starting location
- `options.c1`      : sufficient decrease parameter (default 1e-4)
- `options.c2`      : curvature parameter (default 0.9)
- `options.tolX`    : minimum allowable step length (default 1e-9)
- `options.maxIter` : maximum nb of iterations (default 20)
- `options.verbose` : if true, print a message when `maxIter` is reached
                      (default false)

RETURN:
- `f` : function value at x+t*d
- `g` : gradient value at x+t*d
- `x` : the next x (=x+t*d)
- `t` : the step length
- `lsFuncEval` : the number of function evaluations
]]
function optim.lswolfe(opfunc,x,t,d,f,g,gtd,options)
   -- options
   options = options or {}
   local c1 = options.c1 or 1e-4
   local c2 = options.c2 or 0.9
   local tolX = options.tolX or 1e-9
   local maxIter = options.maxIter or 20
   local isverbose = options.verbose or false

   -- some shortcuts
   local abs = torch.abs
   local min = math.min

   -- verbose function
   local function verbose(...)
      if isverbose then print('<optim.lswolfe> ', ...) end
   end

   -- evaluate objective and gradient using initial step
   -- (x is updated in place, so keep a copy of the starting point)
   local x_init = x:clone()
   x:add(t,d)
   local f_new,g_new = opfunc(x)
   local lsFuncEval = 1
   local gtd_new = g_new * d

   -- bracket an interval containing a point satisfying the Wolfe
   -- criteria
   local LSiter,t_prev,done = 0,0,false
   local f_prev,g_prev,gtd_prev = f,g:clone(),gtd
   local bracket,bracketFval,bracketGval
   while LSiter < maxIter do
      -- check conditions:
      if (f_new > (f + c1*t*gtd)) or (LSiter > 1 and f_new >= f_prev) then
         -- sufficient decrease violated, or no improvement over the
         -- previous trial: an acceptable point lies in [t_prev, t]
         bracket = x.new{t_prev,t}
         bracketFval = x.new{f_prev,f_new}
         bracketGval = x.new(2,g_new:size(1))
         bracketGval[1] = g_prev
         bracketGval[2] = g_new
         break
      elseif abs(gtd_new) <= -c2*gtd then
         -- curvature condition also holds: t is accepted as is, and the
         -- zoom phase below is skipped
         bracket = x.new{t}
         bracketFval = x.new{f_new}
         bracketGval = x.new(1,g_new:size(1))
         bracketGval[1] = g_new
         done = true
         break
      elseif gtd_new >= 0 then
         -- directional derivative is no longer negative: bracket
         -- [t_prev, t]
         bracket = x.new{t_prev,t}
         bracketFval = x.new{f_prev,f_new}
         bracketGval = x.new(2,g_new:size(1))
         bracketGval[1] = g_prev
         bracketGval[2] = g_new
         break
      end

      -- interpolate: pick the next (larger) trial step, constrained to
      -- [minStep, maxStep]
      local tmp = t_prev
      t_prev = t
      local minStep = t + 0.01*(t-tmp)
      local maxStep = t*10
      t = optim.polyinterp(x.new{{tmp,f_prev,gtd_prev},
                                 {t,f_new,gtd_new}},
                           minStep, maxStep)

      -- next step:
      f_prev = f_new
      g_prev = g_new:clone()
      gtd_prev = gtd_new
      x[{}] = x_init
      x:add(t,d)
      f_new,g_new = opfunc(x)
      lsFuncEval = lsFuncEval + 1
      gtd_new = g_new * d
      LSiter = LSiter + 1
   end

   -- reached max nb of iterations?
   -- (no bracket was built above, so fall back to [0, t]; the zoom loop
   -- will not run because LSiter == maxIter)
   if LSiter == maxIter then
      bracket = x.new{0,t}
      bracketFval = x.new{f,f_new}
      bracketGval = x.new(2,g_new:size(1))
      bracketGval[1] = g
      bracketGval[2] = g_new
   end

   -- zoom phase: we now have a point satisfying the criteria, or
   -- a bracket around it. We refine the bracket until we find the
   -- exact point satisfying the criteria
   local insufProgress = false
   while not done and LSiter < maxIter do
      -- find high and low points in bracket
      local f_LO,LOpos = bracketFval:min(1)
      LOpos = LOpos[1] f_LO = f_LO[1]
      local HIpos = -LOpos+3   -- the other slot: 1 -> 2, 2 -> 1

      -- compute new trial value
      t = optim.polyinterp(x.new{{bracket[1],bracketFval[1],bracketGval[1]*d},
                                 {bracket[2],bracketFval[2],bracketGval[2]*d}})

      -- test that we are making sufficient progress: if the trial step is
      -- within 10% of either end of the bracket twice in a row (or falls
      -- on/outside the bracket), move it 10% of the bracket width inside
      local bMax = bracket:max()
      local bMin = bracket:min()
      if min(bMax-t,t-bMin)/(bMax-bMin) < 0.1 then
         if insufProgress or t>=bMax or t<=bMin then
            if abs(t-bMax) < abs(t-bMin) then
               t = bMax-0.1*(bMax-bMin)
            else
               t = bMin+0.1*(bMax-bMin)
            end
            insufProgress = false
         else
            insufProgress = true
         end
      else
         insufProgress = false
      end

      -- Evaluate new point
      x[{}] = x_init
      x:add(t,d)
      f_new,g_new = opfunc(x)
      lsFuncEval = lsFuncEval + 1
      gtd_new = g_new * d
      LSiter = LSiter + 1

      if f_new > f + c1*t*gtd or f_new >= f_LO then
         -- Armijo condition not satisfied or not lower than lowest point:
         -- the new point replaces the HI end of the bracket
         bracket[HIpos] = t
         bracketFval[HIpos] = f_new
         bracketGval[HIpos] = g_new
      else
         if abs(gtd_new) <= - c2*gtd then
            -- Wolfe conditions satisfied
            done = true
         elseif gtd_new*(bracket[HIpos]-bracket[LOpos]) >= 0 then
            -- Old LO becomes new HI
            bracket[HIpos] = bracket[LOpos]
            bracketFval[HIpos] = bracketFval[LOpos]
            bracketGval[HIpos] = bracketGval[LOpos]
         end
         -- New point becomes new LO
         bracket[LOpos] = t
         bracketFval[LOpos] = f_new
         bracketGval[LOpos] = g_new
      end

      -- done? (stop when the bracket width, scaled by the latest
      -- directional derivative, drops below tolX)
      if not done and abs((bracket[1]-bracket[2])*gtd_new) < tolX then
         break
      end
   end

   -- be verbose
   if LSiter == maxIter then
      verbose('reached max number of iterations')
   end

   -- return stuff: pick the bracket entry with the lowest function value
   -- and move x to the corresponding point
   local _,LOpos = bracketFval:min(1)
   LOpos = LOpos[1]
   t = bracket[LOpos]
   f_new = bracketFval[LOpos]
   g_new = bracketGval[LOpos]
   x[{}] = x_init
   x:add(t,d)
   return f_new,g_new,x,t,lsFuncEval
end