-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_max_LL.jl
102 lines (76 loc) · 2.95 KB
/
test_max_LL.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
using PyPlot
using Statistics
using Optim
using ForwardDiff
using JLD
using LineSearches
# Project root holding the model/simulation code, and the figures folder.
path_functions = "/home/genis/wm_mice/"
path_figures = "/home/genis/wm_mice/figures/"
include(joinpath(path_functions, "functions_wm_mice.jl"))
include(joinpath(path_functions, "function_simulations.jl"))

# Model parameter names and the matching parameter vector handed to
# `create_data` (below and in the fitting loop) to simulate choices.
args = ["mu_k", "c2", "c4", "x0", "mu_b", "sigma",
        "beta_w", "beta_l", "tau_w", "tau_l", "PDwDw", "PBiasBias"]
x = [0.3, 1.2, 1.0, 0.15, -0.05, 0.3,
     3.0, -1.0, 10.0, 10.0, 0.9, 0.8]
param = make_dict2(args, x)

# Delay conditions and number of simulated trials per dataset.
delays = Float64[0, 100, 200, 300, 500, 800, 1000]
Ntrials = 10_000

# Reference synthetic dataset generated from `x`.
choices, state, stim, past_choices, past_rewards, idelays =
    create_data(Ntrials, delays, args, x)
"""
    LL_f(y)

Return `Compute_negative_LL` for the dataset currently bound to the globals
`stim`, `choices`, `past_choices`, `past_rewards` and `idelays`. The function
depends only on the two free parameters `y = [mu_k, beta_w]`. Every other
parameter is held at its value in the global vector `x`.

The parameter vector passed on has element type
`promote_type(eltype(x), eltype(y))`. This lets `y` hold ForwardDiff dual
numbers (for gradients/Hessians) as well as plain integers, which previously
raised an `InexactError`.
"""
function LL_f(y)
    z = similar(x, promote_type(eltype(x), eltype(y)))
    copyto!(z, x)
    # Free parameters: index 1 is "mu_k" and index 7 is "beta_w" in `args`.
    # (To fit (c2, sigma) instead, put y[1] in z[2] and y[2] in z[6].)
    z[1] = y[1]
    z[7] = y[2]
    return Compute_negative_LL(stim, delays, idelays, choices, past_choices,
                               past_rewards, args, z)
end
# Generating values of the free parameters (mu_k, beta_w) = (x[1], x[7]).
y = x[[1, 7]]
ll_original = Compute_negative_LL(stim, delays, idelays, choices, past_choices,
                                  past_rewards, args, x)

# Search box for (mu_k, beta_w).
# (Box used when fitting (c2, sigma) instead: lower = [0.2, 0.05], upper = [4, 0.6].)
lower = fill(-10.0, 2)
upper = fill(10.0, 2)

# Experiment size: number of simulated datasets, and random starts per dataset.
NDataSets = 2
Nconditions = 100

# Result containers, indexed [dataset, start, ...].
Ymin = zeros(NDataSets, Nconditions, length(lower))         # fitted parameters
Yini = zero(Ymin)                                           # starting points
LL = zeros(NDataSets, Nconditions)                          # objective at the fit
Hess = zeros(NDataSets, Nconditions, length(y), length(y))  # Hessian at the fit
LlOriginal = zeros(NDataSets)                               # objective at `x`
# Parameter-recovery loop. For each of NDataSets freshly simulated datasets:
# evaluate the objective at the generating parameters `x`, then refit
# (mu_k, beta_w) from Nconditions uniform random starts inside [lower, upper].
# For each start, store the fitted point, the start, the objective value and
# the ForwardDiff Hessian at the fit.
for iDataSet in 1:NDataSets
    println(iDataSet)

    # Explicitly local: without this, a top-level loop assignment to names that
    # already exist as globals is an ambiguous soft-scope assignment. It creates
    # a new local (with a warning) when run as a script, but overwrites the
    # globals in the REPL.
    local choices, state, stim, past_choices, past_rewards, idelays
    choices, state, stim, past_choices, past_rewards, idelays =
        create_data(Ntrials, delays, args, x)
    LlOriginal[iDataSet] = Compute_negative_LL(stim, delays, idelays, choices,
                                               past_choices, past_rewards, args, x)

    # Objective for THIS dataset as a function of y = [mu_k, beta_w] only.
    # All other parameters stay at their values in `x`.
    function LL_f2(y)
        # Promote so ForwardDiff duals (and integer inputs) can be stored in z.
        z = similar(x, promote_type(eltype(x), eltype(y)))
        copyto!(z, x)
        z[1] = y[1]  # "mu_k"   (for (c2, sigma): z[2] = y[1])
        z[7] = y[2]  # "beta_w" (for (c2, sigma): z[6] = y[2])
        return Compute_negative_LL(stim, delays, idelays, choices,
                                   past_choices, past_rewards, args, z)
    end

    for icondition in 1:Nconditions
        # Uniform random starting point in the box [lower, upper].
        y0 = rand(length(upper)) .* (upper .- lower) .+ lower
        res = optimize(LL_f2, lower, upper, y0,
                       Fminbox(LBFGS(linesearch = BackTracking(order = 2)));
                       autodiff = :forward)
        ymin = Optim.minimizer(res)
        Ymin[iDataSet, icondition, :] = ymin
        Yini[iDataSet, icondition, :] = res.initial_x
        Hess[iDataSet, icondition, :, :] = ForwardDiff.hessian(LL_f2, ymin)
        LL[iDataSet, icondition] = Optim.minimum(res)
    end
end
# Output file for the (mu_k, beta_w) fits. When fitting (c2, sigma) instead,
# the results were saved in the same folder as
# "minimize_sigma_c2_Ntrials<Ntrials>_NDataSets<NDataSets>.jld".
dir_save = "/home/genis/wm_mice/synthetic_data/"
filename_save = joinpath(dir_save,
    "minimize_muk_betaw_Ntrials$(Ntrials)_NDataSets$(NDataSets).jld")
# The loop above runs NDataSets * Nconditions bounded optimizations. Create
# the folder if needed so that `save` cannot fail at the very end and lose them.
mkpath(dir_save)
save(filename_save, "x", x, "Ymin", Ymin, "Yini", Yini, "LL", LL, "Hess", Hess,
     "args", args, "LlOriginal", LlOriginal)