Skip to content

Commit

Permalink
refactor: add JumpProblem to remake tests
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Jan 3, 2024
1 parent 336c932 commit d4d5466
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 16 deletions.
7 changes: 5 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ EnumX = "1"
FillArrays = "1.9"
FunctionWrappersWrappers = "0.1.3"
IteratorInterfaceExtensions = "^1"
JumpProcesses = "9.10.1"
LinearAlgebra = "1.9"
Logging = "1.9"
Markdown = "1.9"
Expand All @@ -81,7 +82,7 @@ SciMLOperators = "0.3.7"
StaticArrays = "1.7"
StaticArraysCore = "1.4"
Statistics = "1.9"
SymbolicIndexingInterface = "0.3"
SymbolicIndexingInterface = "0.3.2"
Tables = "1.11"
TruncatedStacktraces = "1.4"
Zygote = "0.6.67"
Expand All @@ -92,6 +93,7 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
DelayDiffEq = "bcd4f6db-9728-5f36-b5f7-82caef46ccdb"
JumpProcesses = "ccbc3e58-028d-4f4c-8cd5-9ae44345cda5"
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b"
Expand All @@ -102,8 +104,9 @@ RCall = "6f49c342-dc21-5d91-9882-a32aef131414"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Pkg", "PyCall", "PythonCall", "SafeTestsets", "Test", "StaticArrays", "StochasticDiffEq", "Aqua", "Zygote", "PartialFunctions", "DataFrames", "ModelingToolkit", "OrdinaryDiffEq"]
test = ["Pkg", "PyCall", "PythonCall", "SafeTestsets", "Test", "StaticArrays", "StochasticDiffEq", "Aqua", "Zygote", "PartialFunctions", "DataFrames", "ModelingToolkit", "OrdinaryDiffEq", "JumpProcesses", "SymbolicIndexingInterface"]
62 changes: 48 additions & 14 deletions test/remake.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using ModelingToolkit, SymbolicIndexingInterface
using JumpProcesses

@parameters σ ρ β
@variables t x(t) y(t) z(t)
Expand Down Expand Up @@ -51,8 +52,8 @@ sprob2 = remake(
u0 = [x => 2.0, sys.y => 1.2, :z => 1.0],
p ==> 29.0, sys.ρ => 11.0, => 3.0]
)
@test sprob.u0 isa Vector{<:Number}
@test sprob.p isa Vector{<:Number}
@test sprob2.u0 isa Vector{<:Number}
@test sprob2.p isa Vector{<:Number}
@test sprob2[x] == sprob2[sys.x] == sprob2[:x] == 2.0
@test sprob2[y] == sprob2[sys.y] == sprob2[:y] == 1.2
@test sprob2[z] == sprob2[sys.z] == sprob2[:z] == 1.0
Expand All @@ -78,17 +79,17 @@ dprob2 = remake(
u0 = [x => 2.0, sys.y => 1.2, :z => 1.0],
p ==> 29.0, sys.ρ => 11.0, => 3.0]
)
@test dprob.u0 isa Vector{<:Number}
@test dprob.p isa Vector{<:Number}
@test dprob2.u0 isa Vector{<:Number}
@test dprob2.p isa Vector{<:Number}
@test dprob2[x] == dprob2[sys.x] == dprob2[:x] == 2.0
@test dprob2[y] == dprob2[sys.y] == dprob2[:y] == 1.2
@test dprob2[z] == dprob2[sys.z] == dprob2[:z] == 1.0
@test getp(sys, σ)(dprob2) == 29.0
@test getp(sys, sys.ρ)(dprob2) == 11.0
@test getp(sys, )(dprob2) == 3.0
@test getp(de, σ)(dprob2) == 29.0
@test getp(de, sys.ρ)(dprob2) == 11.0
@test getp(de, )(dprob2) == 3.0

dprob3 = remake(dprob; p ==> 30.0]) # partial update
@test getp(sys, σ)(dprob3) == 30.0
@test getp(de, σ)(dprob3) == 30.0

# NonlinearProblem
@named ns = NonlinearSystem(
Expand All @@ -105,14 +106,47 @@ nlprob2 = remake(
u0 = [x => 2.0, sys.y => 1.2, :z => 1.0],
p ==> 29.0, sys.ρ => 11.0, => 3.0]
)
@test nlprob.u0 isa Vector{<:Number}
@test nlprob.p isa Vector{<:Number}
@test nlprob2.u0 isa Vector{<:Number}
@test nlprob2.p isa Vector{<:Number}
@test nlprob2[x] == nlprob2[sys.x] == nlprob2[:x] == 2.0
@test nlprob2[y] == nlprob2[sys.y] == nlprob2[:y] == 1.2
@test nlprob2[z] == nlprob2[sys.z] == nlprob2[:z] == 1.0
@test getp(sys, σ)(nlprob2) == 29.0
@test getp(sys, sys.ρ)(nlprob2) == 11.0
@test getp(sys, )(nlprob2) == 3.0
@test getp(ns, σ)(nlprob2) == 29.0
@test getp(ns, sys.ρ)(nlprob2) == 11.0
@test getp(ns, )(nlprob2) == 3.0

nlprob3 = remake(nlprob; p ==> 30.0]) # partial update
@test getp(sys, σ)(nlprob3) == 30.0
@test getp(ns, σ)(nlprob3) == 30.0

@parameters β γ
@variables t S(t) I(t) R(t)
rate₁ = β*S*I
affect₁ = [S ~ S - 1, I ~ I + 1]
rate₂ = γ*I
affect₂ = [I ~ I - 1, R ~ R + 1]
j₁ = ConstantRateJump(rate₁,affect₁)
j₂ = ConstantRateJump(rate₂,affect₂)
j₃ = MassActionJump(2*β+γ, [R => 1], [S => 1, R => -1])
@named js = JumpSystem([j₁,j₂,j₃], t, [S,I,R], [β,γ])

u₀map = [S => 999, I => 1, R => 0]
parammap ==> 0.1 / 1000, γ => 0.01]
tspan = (0.0, 250.0)
jump_dprob = DiscreteProblem(js, u₀map, tspan, parammap)
jprob = JumpProblem(js, jump_dprob, Direct())

jprob2 = remake(
jprob;
u0 = [S => 900, js.I => 2, :R => 0.1],
p ==> 0.2 / 1000, js.γ => 11.0]
)
@test jprob2.prob.u0 isa Vector{<:Number}
@test jprob2.prob.p isa Vector{<:Number}
@test jprob2[S] == jprob2[js.S] == jprob2[:S] == 900.0
@test jprob2[I] == jprob2[js.I] == jprob2[:I] == 2.0
@test jprob2[R] == jprob2[js.R] == jprob2[:R] == 0.1
@test getp(js, β)(jprob2) == 0.2 / 1000
@test getp(js, js.γ)(jprob2) == 11.0

jprob3 = remake(jprob; p = [ => 0.3 / 1000]) # partial update
@test getp(js, β)(jprob3) == 0.3 / 1000

0 comments on commit d4d5466

Please sign in to comment.