Skip to content

Commit

Permalink
more thread-safe repeatability
Browse files Browse the repository at this point in the history
  • Loading branch information
ancorso committed Dec 14, 2023
1 parent a643675 commit 01a7a29
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 11 deletions.
26 changes: 18 additions & 8 deletions examples/geothermal_example.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ using CSV
using Random
using DataStructures
using Plots
default(framestyle = :box, color_palette=:seaborn_deep6, fontfamily="Computer Modern")
default(framestyle = :box, color_palette=:seaborn_deep6, fontfamily="Computer Modern", margin=5mm)

# Define random seeds
# When true, the RNG seed is set before each policy generation and each evaluation.
# The seed is described as "eval index + test set" (presumably the test-set index —
# TODO confirm against the evaluation code). The scheme is intended to be thread-safe.
fix_solve_and_eval_seed = true
Expand Down Expand Up @@ -140,14 +140,14 @@ option11_pol(pomdp) = FixedPolicy([Symbol("Option 10")])
# Policy constructors: each takes a POMDP and returns a policy object built for it.

# FixedPolicy built from a one-element list holding Symbol("Option 11").
# NOTE(review): the function is named option13_pol and is labelled "Option 13" in
# policy_names, yet it selects "Option 11" — confirm the intended option mapping.
option13_pol(pomdp) = FixedPolicy([Symbol("Option 11")])
# FixedPolicy over obs_actions in their given order, with BestCurrentOption(pomdp) as its
# second argument, wrapped in EnsureParticleCount together with BestCurrentOption(pomdp)
# and min_particles (presumably a fallback policy and a particle-count threshold —
# confirm against EnsureParticleCount).
all_policy_geo(pomdp) = EnsureParticleCount(FixedPolicy(obs_actions, BestCurrentOption(pomdp)), BestCurrentOption(pomdp), min_particles)
# Same construction as all_policy_geo, but with obs_actions reversed.
all_policy_econ(pomdp) = EnsureParticleCount(FixedPolicy(reverse(obs_actions), BestCurrentOption(pomdp)), BestCurrentOption(pomdp), min_particles)
# RandPolicy wrappers with prob_terminal = 0.1, 0.25 and 0.5.
# These three are named after 1/prob_terminal (10, 4, 2).
random_policy_10(pomdp) = EnsureParticleCount(RandPolicy(;pomdp, prob_terminal=0.1), BestCurrentOption(pomdp), min_particles)
random_policy_4(pomdp) = EnsureParticleCount(RandPolicy(;pomdp, prob_terminal=0.25), BestCurrentOption(pomdp), min_particles)
random_policy_2(pomdp) = EnsureParticleCount(RandPolicy(;pomdp, prob_terminal=0.5), BestCurrentOption(pomdp), min_particles)
# OneStepGreedyPolicy wrapped in EnsureParticleCount (the "VOI" policy).
voi_policy(pomdp) = EnsureParticleCount(OneStepGreedyPolicy(;pomdp), BestCurrentOption(pomdp), min_particles)
# Policy obtained by solving the POMDP with SARSOPSolver using its default settings.
sarsop_policy(pomdp) = EnsureParticleCount(solve(SARSOPSolver(), pomdp), BestCurrentOption(pomdp), min_particles)
# RandPolicy wrappers with prob_terminal = 0.1, 0.25 and 0.5.
# These three are named after the prob_terminal value itself (0_1, 0_25, 0_5).
random_policy_0_1(pomdp) = EnsureParticleCount(RandPolicy(;pomdp, prob_terminal=0.1), BestCurrentOption(pomdp), min_particles)
random_policy_0_25(pomdp) = EnsureParticleCount(RandPolicy(;pomdp, prob_terminal=0.25), BestCurrentOption(pomdp), min_particles)
random_policy_0_5(pomdp) = EnsureParticleCount(RandPolicy(;pomdp, prob_terminal=0.5), BestCurrentOption(pomdp), min_particles)
# voi_policy(pomdp) = EnsureParticleCount(OneStepGreedyPolicy(;pomdp), BestCurrentOption(pomdp), min_particles)
# Policy obtained by solving the POMDP with SARSOPSolver(max_time=10.0)
# (presumably a solve-time limit in seconds — confirm against the SARSOP solver docs).
sarsop_policy(pomdp) = EnsureParticleCount(solve(SARSOPSolver(max_time=10.0), pomdp), BestCurrentOption(pomdp), min_particles)

# Combine the policy constructors into a list. policy_names holds one label per entry,
# in the same order (presumably used when reporting results — confirm against the caller).
# voi_policy / "VOI Policy" is left out of both lists for now.
policies = [option7_pol, option11_pol, option13_pol, all_policy_geo, all_policy_econ, random_policy_10, random_policy_4, random_policy_2, sarsop_policy] # voi_policy
policies = [option7_pol, option11_pol, option13_pol, all_policy_geo, all_policy_econ, random_policy_0_1, random_policy_0_25, random_policy_0_5, sarsop_policy] # voi_policy
policy_names = ["Option 7", "Option 11", "Option 13", "All Data Policy (Geo First)", "All Data Policy (Econ First)", "Random Policy (Pstop=0.1)", "Random Policy (Pstop=0.25)", "Random Policy (Pstop=0.5)", "SARSOP Policy"] # "VOI Policy"

# Evaluate the policies on the test set
Expand Down Expand Up @@ -229,7 +229,17 @@ elseif split_by == :both
end

# Pair the geo and econ training fractions element-wise as (geo, econ) tuples.
zip_fracs = [(g,e) for (g, e) in zip(geo_fracs, econ_fracs)]
# Build the POMDPs and their matching test sets from the fraction pairs (presumably one
# group per pair — confirm against create_pomdps_with_different_training_fractions).
# pomdp_gen_seed, discount_factor and split_by are forwarded as keyword arguments.
pomdps_per_geo, test_sets_per_geo = create_pomdps_with_different_training_fractions(zip_fracs, scenario_csvs, geo_params, econ_params, obs_actions, Nbins; rng_seed=pomdp_gen_seed, discount=discount_factor, split_by)
pomdps_per_geo, test_sets_per_geo = create_pomdps_with_different_training_fractions(
zip_fracs,
scenario_csvs,
geo_params,
econ_params,
obs_actions,
Nbins;
rng_seed=pomdp_gen_seed,
discount=discount_factor,
split_by)


# Solve the policies and evaluate the results #<---- Uncomment the below lines to solve and eval the policies
results = Dict()
Expand Down
6 changes: 3 additions & 3 deletions src/pomdp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -296,16 +296,16 @@ function create_pomdps(scenario_csvs, geo_params, econ_params, obs_actions, Nbin
p = Progress(nfolds, 1, "Creating POMDPs...")
Threads.@threads for i in 1:nfolds
# Set up fold-specific random state (keyed on the fold index i, not on the thread id)
thread_rng = MersenneTwister(i)
Random.seed!(rng_seed+i)

# Generate the train set by combining all of the folds except the ith one
train = train_sets[i]

# Discretize the observations
discrete_obs = get_discrete_observations(train, obs_actions, Nbins; rng=thread_rng)
discrete_obs = get_discrete_observations(train, obs_actions, Nbins;)

# Create categorical observation distributions
obs_dists = create_observation_distributions(train, obs_actions, discrete_obs, Nbins .* Nsamples_per_bin; rng=thread_rng)
obs_dists = create_observation_distributions(train, obs_actions, discrete_obs, Nbins .* Nsamples_per_bin;)

# Make the POMDP and return the val and test sets
pomdps[i] = InfoGatheringPOMDP(train, obs_actions, keys(scenario_csvs), discrete_obs, obs_dists, (o) -> nearest_neighbor_mapping(o, discrete_obs), discount)
Expand Down

0 comments on commit 01a7a29

Please sign in to comment.