Skip to content

Commit

Permalink
Small fix for funs.jl.
Browse files Browse the repository at this point in the history
  • Loading branch information
albop committed Jun 16, 2024
1 parent 663d857 commit c935029
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 10 deletions.
8 changes: 4 additions & 4 deletions src/algos/time_iteration.jl
Original file line number Diff line number Diff line change
Expand Up @@ -147,12 +147,12 @@ function time_iteration_workspace(dmodel; interp_mode=:linear, improve=false, de
vars = variables(dmodel.model.controls)
φ = DFun(dmodel.model.states, x0, vars; interp_mode=interp_mode)

# if improve
if improve
L = Dolo.dF_2(dmodel, x1, φ)
tt = (;x0, x1, x2, r0, dx, J, L, φ)
# else
# tt = (;x0, x1, x2, r0, dx, J, φ)
# end
else
tt = (;x0, x1, x2, r0, dx, J, φ)
end

return adapt(dest, tt)

Expand Down
18 changes: 12 additions & 6 deletions src/funs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -109,13 +109,19 @@ function fit!(φ::DFun, x::GVector{G}) where G<:CGrid

end

## PGrid
## PGrid ( SGrid × CGrid )


function (f::DFun{A,B,I,vars})(x::QP) where A where B<:GArray{G,V} where V where I where G<:PGrid{G1,G2} where G1<:SGrid where G2<:CGrid where vars
f(x.loc...)
end

function (f::DFun{A,B,I,vars})(i::Int, x::SVector{d2, U}) where A where B<:GArray{G,V} where V where I where G<:PGrid{G1,G2} where G1<:SGrid where G2<:CGrid where vars where d2 where U
f.itp[i](x)
end

function (f::DFun{A,B,I,vars})(x::QP) where A where B<:GArray{G,V} where V where I where G<:PGrid{G1,G2} where G1<:SGrid where G2<:CGrid where vars
f(x.loc...)
function (f::DFun{A,B,I,vars})(i::Int, j::Int) where A where B<:GArray{G,V} where V where I where G<:PGrid{G1,G2} where G1<:SGrid where G2<:CGrid where vars where d2 where U
f.values[i,j]
end

function (f::DFun{A,B,I,vars})(x::Tuple) where A where B<:GArray{G,V} where V where I where G<:PGrid{G1,G2} where G1<:SGrid where G2<:CGrid where vars
Expand Down Expand Up @@ -148,9 +154,9 @@ end

# Compatibility calls

(f::DFun)(x::Real) = f(SVector(x))
(f::DFun)(x::Real, y::Real) = f(SVector(x,y))
(f::DFun)(x::Vector{SVector{d,<:Real}}) where d = [f(e) for e in x]
# (f::DFun)(x::Real) = f(SVector(x))
# (f::DFun)(x::Real, y::Real) = f(SVector(x,y))
# (f::DFun)(x::Vector{SVector{d,<:Real}}) where d = [f(e) for e in x]


ndims(df::DFun) = ndims(df.domain)
Expand Down

0 comments on commit c935029

Please sign in to comment.