Skip to content

Commit

Permalink
Some type specs
Browse files Browse the repository at this point in the history
  • Loading branch information
albop committed Nov 10, 2024
1 parent 33ba49f commit 0a74219
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 14 deletions.
20 changes: 20 additions & 0 deletions dev/dl.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
using Dolo

model = include("../examples/ymodels/consumption_savings_iid.jl")


# dmodel = Dolo.discretize(model)

# wk = Dolo.time_iteration_workspace(dmodel)

# r = Dolo.F(dmodel, wk.x0, wk.φ)

import Dolo: F

function F(model::Dolo.AModel,s,φ)



end

Dolo.rand(model.states)
28 changes: 14 additions & 14 deletions src/algos/time_iteration.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@



function F(dmodel::M, s::QP, x::SVector{d,T}, φ::Union{Policy, GArray, DFun}) where M where d where T
function F(dmodel::ADModel, s::QP, x::SVector{d,T}, φ::Union{Policy, GArray, DFun}) where d where T

r = zero(SVector{d,T})

Expand All @@ -39,7 +39,7 @@ function F(dmodel::M, s::QP, x::SVector{d,T}, φ::Union{Policy, GArray, DFun}) w
end


F(model, controls::GArray, φ::Union{GArray, DFun}) =
F(model::ADModel, controls::GArray, φ::Union{GArray, DFun}) =
GArray(
model.grid,
[
Expand All @@ -48,7 +48,7 @@ F(model, controls::GArray, φ::Union{GArray, DFun}) =
],
)

function F!(out, model::M, controls::XT, φ::Union{GArray, DFun}) where M where XT<:GArray{PG, Vector{SVector{d,T}}} where PG where d where T
function F!(out, model::ADModel, controls::XT, φ::Union{GArray, DFun}) where XT<:GArray{PG, Vector{SVector{d,T}}} where PG where d where T

for s in enum(model.grid)
x = controls[s.loc...]
Expand All @@ -59,7 +59,7 @@ function F!(out, model::M, controls::XT, φ::Union{GArray, DFun}) where M where
end


function F!(out, model::M, controls::XT, φ::Union{GArray, DFun}, ::Nothing) where M where XT<:GArray{PG, Vector{SVector{d,T}}} where PG where d where T
function F!(out, model::ADModel, controls::XT, φ::Union{GArray, DFun}, ::Nothing) where XT<:GArray{PG, Vector{SVector{d,T}}} where PG where d where T

for s in enum(model.grid)
x = controls[s.loc...]
Expand All @@ -71,9 +71,9 @@ end


## no alloc
dF_1(model, s, x::SVector, φ) = ForwardDiff.jacobian(u->F(model, s, u, φ), x)
dF_1(model::ADModel, s, x::SVector, φ) = ForwardDiff.jacobian(u->F(model, s, u, φ), x)

dF_1(model, controls::GArray, φ::Union{GArray, DFun}) =
dF_1(model::ADModel, controls::GArray, φ::Union{GArray, DFun}) =
GArray( # this shouldn't be needed
model.grid,
[
Expand All @@ -85,11 +85,11 @@ dF_1(model, controls::GArray, φ::Union{GArray, DFun}) =

## no alloc

function dF_1!(out, model, controls::GArray, φ::Union{GArray, DFun}, ::Nothing)
function dF_1!(out::ADModel, model, controls::GArray, φ::Union{GArray, DFun}, ::Nothing)
dF_1!(out, model, controls::GArray, φ::Union{GArray, DFun})
end

function dF_1!(out, model, controls::GArray, φ::Union{GArray, DFun})
function dF_1!(out, model::ADModel, controls::GArray, φ::Union{GArray, DFun})

i = 0
for s in enum(model.grid)
Expand All @@ -106,20 +106,20 @@ end #### no alloc

# warning: the versions of dF_2 don't have complementarities

dF_2(model, s, x::SVector, φ::GArray, dφ::GArray) =
dF_2(model::ADModel, s, x::SVector, φ::GArray, dφ::GArray) =
sum(
w*ForwardDiff.jacobian(u->arbitrage(model,s,x,S,u), φ(S))* (S)
for (w, S) in τ(model, s, x)
) ### no alloc


dF_2(model, controls::GArray, φ::GArray, dφ::GArray) =
dF_2(model::ADModel, controls::GArray, φ::GArray, dφ::GArray) =
GArray(
model.grid,
[(dF_2(model,s,x,φ,dφ)) for (s,x) in zip(enum(model.grid), controls) ],
)

function dF_2!(out::GArray, model, controls::GArray, φ::GArray, dφ::GArray)
function dF_2!(out::GArray, model::ADModel, controls::GArray, φ::GArray, dφ::GArray)
for (n, (s,x)) in enumerate(zip(enum(model.grid), controls))
out[n] = dF_2(model, s, x, φ, dφ)
end
Expand Down Expand Up @@ -158,11 +158,11 @@ function time_iteration_workspace(dmodel; interp_mode=:linear, improve=false, de

end

function newton_workspace(model; interp_mode=:linear)
function newton_workspace(dmodel::ADModel; interp_mode=:linear)


res = time_iteration_workspace(model; interp_mode=interp_mode)
T = Dolo.dF_2(model, res.x0, res.φ)
res = time_iteration_workspace(dmodel; interp_mode=interp_mode)
T = Dolo.dF_2(dmodel, res.x0, res.φ)
res = merge(res, (;T=T,memn=(;du=deepcopy(res.x0), dv=deepcopy(res.x0))))
return res
end
Expand Down

0 comments on commit 0a74219

Please sign in to comment.