From 01a7a29b536d14ff20534f5daf686f2d250ea3ab Mon Sep 17 00:00:00 2001 From: Anthony Corso Date: Thu, 14 Dec 2023 15:11:50 -0800 Subject: [PATCH] more thread-safe repeatability --- examples/geothermal_example.jl | 26 ++++++++++++++++++-------- src/pomdp.jl | 6 +++--- 2 files changed, 21 insertions(+), 11 deletions(-) diff --git a/examples/geothermal_example.jl b/examples/geothermal_example.jl index 3d6f3fa..71f0836 100644 --- a/examples/geothermal_example.jl +++ b/examples/geothermal_example.jl @@ -7,7 +7,7 @@ using CSV using Random using DataStructures using Plots -default(framestyle = :box, color_palette=:seaborn_deep6, fontfamily="Computer Modern") +default(framestyle = :box, color_palette=:seaborn_deep6, fontfamily="Computer Modern", margin=5mm) # Define random seeds fix_solve_and_eval_seed = true # Whether the seed is set before each policy gen and evaluation. Seed is the eval index + test set. It is threadsafe. @@ -140,14 +140,14 @@ option11_pol(pomdp) = FixedPolicy([Symbol("Option 10")]) option13_pol(pomdp) = FixedPolicy([Symbol("Option 11")]) all_policy_geo(pomdp) = EnsureParticleCount(FixedPolicy(obs_actions, BestCurrentOption(pomdp)), BestCurrentOption(pomdp), min_particles) all_policy_econ(pomdp) = EnsureParticleCount(FixedPolicy(reverse(obs_actions), BestCurrentOption(pomdp)), BestCurrentOption(pomdp), min_particles) -random_policy_10(pomdp) = EnsureParticleCount(RandPolicy(;pomdp, prob_terminal=0.1), BestCurrentOption(pomdp), min_particles) -random_policy_4(pomdp) = EnsureParticleCount(RandPolicy(;pomdp, prob_terminal=0.25), BestCurrentOption(pomdp), min_particles) -random_policy_2(pomdp) = EnsureParticleCount(RandPolicy(;pomdp, prob_terminal=0.5), BestCurrentOption(pomdp), min_particles) -voi_policy(pomdp) = EnsureParticleCount(OneStepGreedyPolicy(;pomdp), BestCurrentOption(pomdp), min_particles) -sarsop_policy(pomdp) = EnsureParticleCount(solve(SARSOPSolver(), pomdp), BestCurrentOption(pomdp), min_particles) +random_policy_0_1(pomdp) = EnsureParticleCount(RandPolicy(;pomdp, prob_terminal=0.1), BestCurrentOption(pomdp), min_particles) +random_policy_0_25(pomdp) = EnsureParticleCount(RandPolicy(;pomdp, prob_terminal=0.25), BestCurrentOption(pomdp), min_particles) +random_policy_0_5(pomdp) = EnsureParticleCount(RandPolicy(;pomdp, prob_terminal=0.5), BestCurrentOption(pomdp), min_particles) +# voi_policy(pomdp) = EnsureParticleCount(OneStepGreedyPolicy(;pomdp), BestCurrentOption(pomdp), min_particles) +sarsop_policy(pomdp) = EnsureParticleCount(solve(SARSOPSolver(max_time=10.0), pomdp), BestCurrentOption(pomdp), min_particles) # combine policies into a list -policies = [option7_pol, option11_pol, option13_pol, all_policy_geo, all_policy_econ, random_policy_10, random_policy_4, random_policy_2, sarsop_policy] # voi_policy +policies = [option7_pol, option11_pol, option13_pol, all_policy_geo, all_policy_econ, random_policy_0_1, random_policy_0_25, random_policy_0_5, sarsop_policy] # voi_policy policy_names = ["Option 7", "Option 11", "Option 13", "All Data Policy (Geo First)", "All Data Policy (Econ First)", "Random Policy (Pstop=0.1)", "Random Policy (Pstop=0.25)", "Random Policy (Pstop=0.5)", "SARSOP Policy"] # "VOI Policy" # Evaluate the policies on the test set @@ -229,7 +229,17 @@ elseif split_by == :both end zip_fracs = [(g,e) for (g, e) in zip(geo_fracs, econ_fracs)] -pomdps_per_geo, test_sets_per_geo = create_pomdps_with_different_training_fractions(zip_fracs, scenario_csvs, geo_params, econ_params, obs_actions, Nbins; rng_seed=pomdp_gen_seed, discount=discount_factor, split_by) +pomdps_per_geo, test_sets_per_geo = create_pomdps_with_different_training_fractions( + zip_fracs, + scenario_csvs, + geo_params, + econ_params, + obs_actions, + Nbins; + rng_seed=pomdp_gen_seed, + discount=discount_factor, + split_by) + # Solve the policies and evaluate the results #<---- Uncomment the below lines to solve and eval the policies results = Dict() diff --git a/src/pomdp.jl b/src/pomdp.jl index f4dc179..d0385d4 100644 --- a/src/pomdp.jl +++ b/src/pomdp.jl @@ -296,16 +296,16 @@ function create_pomdps(scenario_csvs, geo_params, econ_params, obs_actions, Nbin p = Progress(nfolds, 1, "Creating POMDPs...") Threads.@threads for i in 1:nfolds # Create a specific rng for each thread - thread_rng = MersenneTwister(i) + Random.seed!(rng_seed+i) # Generate the train set by combining all of the folds except the ith one train = train_sets[i] # Discretize the observations - discrete_obs = get_discrete_observations(train, obs_actions, Nbins; rng=thread_rng) + discrete_obs = get_discrete_observations(train, obs_actions, Nbins;) # Create categorical observation distributions - obs_dists = create_observation_distributions(train, obs_actions, discrete_obs, Nbins .* Nsamples_per_bin; rng=thread_rng) + obs_dists = create_observation_distributions(train, obs_actions, discrete_obs, Nbins .* Nsamples_per_bin;) # Make the POMDP and return the val and test sets pomdps[i] = InfoGatheringPOMDP(train, obs_actions, keys(scenario_csvs), discrete_obs, obs_dists, (o) -> nearest_neighbor_mapping(o, discrete_obs), discount)