Skip to content

Commit

Permalink
adaptively discount reset_prob
Browse files Browse the repository at this point in the history
  • Loading branch information
MasWag committed May 24, 2023
1 parent b57a922 commit 50a56c9
Showing 1 changed file with 39 additions and 15 deletions.
54 changes: 39 additions & 15 deletions src/ProbBlackBoxChecking.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,8 +191,8 @@ def initialize_strategy_bridge_and_smc(sul, prism_model_path, prism_adv_path, sp
# PRISM + SMC による反例探索オラクル
class ProbBBReachOracle(RandomWalkEqOracle):
def __init__(self, prism_model_path, prism_adv_path, prism_prop_path, ltl_prop_path, alphabet: list, sul: SUL,
smc_max_exec=5000, num_steps=5000, reset_after_cex=True, reset_prob=0.09, statistical_test_bound=0.025,
only_classical_equivalence_testing=False,
smc_max_exec=5000, num_steps=5000, reset_after_cex=True, initial_reset_prob=0.25,
statistical_test_bound=0.025, only_classical_equivalence_testing=False,
output_dir='results', save_files_for_each_round=False, debug=False):
self.prism_model_path = prism_model_path
self.prism_adv_path = prism_adv_path
Expand All @@ -208,7 +208,15 @@ def __init__(self, prism_model_path, prism_adv_path, prism_prop_path, ltl_prop_p
self.save_files_for_each_round = save_files_for_each_round
self.debug = debug
self.rounds = 0
super().__init__(alphabet, sul=sul, num_steps=num_steps, reset_after_cex=reset_after_cex, reset_prob=reset_prob)
# We discount the reset probability so that any length of traces are sampled in the limit.
self.reset_prob_discount = 0.90
self.learned_strategy = None
super().__init__(alphabet, sul=sul, num_steps=num_steps, reset_after_cex=reset_after_cex,
reset_prob=initial_reset_prob)

def discount_reset_prob(self):
logging.info(f"discount reset_prob to {self.reset_prob}")
self.reset_prob *= self.reset_prob_discount

def find_cex(self, hypothesis):
self.rounds += 1
Expand Down Expand Up @@ -268,6 +276,8 @@ def find_cex(self, hypothesis):
logging.info("Run equivalence testing of L*mdp.")
cex = super().find_cex(hypothesis)
logging.info(f"CEX from EQ testing : {cex}")
if cex is None:
self.discount_reset_prob()
return cex

if not os.path.isfile(self.prism_adv_path):
Expand All @@ -276,6 +286,8 @@ def find_cex(self, hypothesis):
logging.info("Run equivalence testing of L*mdp.")
cex = super().find_cex(hypothesis)
logging.info(f"CEX from EQ testing : {cex}")
if cex is None:
self.discount_reset_prob()
return cex

self.learned_strategy = self.prism_adv_path
Expand Down Expand Up @@ -328,6 +340,8 @@ def find_cex(self, hypothesis):
logging.info("Run equivalence testing of L*mdp.")
cex = super().find_cex(hypothesis) # equivalence testing
logging.info(f'CEX from EQ testing : {cex}')
if cex is None:
self.discount_reset_prob()

return cex

Expand Down Expand Up @@ -355,7 +369,7 @@ def learn_mdp_and_strategy(mdp_model_path, prism_model_path, prism_adv_path, pri
automaton_type='smm', n_c=20, n_resample=1000, min_rounds=20, max_rounds=240,
strategy='normal', cex_processing='longest_prefix', stopping_based_on_prop=None,
target_unambiguity=0.99, eq_num_steps=2000,
smc_max_exec=5000, smc_statistical_test_bound=0.025,
smc_max_exec=5000, smc_statistical_test_bound=0.025, eq_test_initial_reset_prob=0.25,
only_classical_equivalence_testing=False,
samples_cex_strategy=None, output_dir='results', save_files_for_each_round=False,
debug=False):
Expand All @@ -367,7 +381,7 @@ def learn_mdp_and_strategy(mdp_model_path, prism_model_path, prism_adv_path, pri
return learn_mdp_and_strategy_from_sul(sul, input_alphabet, prism_model_path, prism_adv_path, prism_prop_path,
ltl_prop_path, automaton_type, n_c, n_resample, min_rounds, max_rounds,
strategy, cex_processing, stopping_based_on_prop, target_unambiguity,
eq_num_steps, smc_max_exec, smc_statistical_test_bound,
eq_num_steps, smc_max_exec, smc_statistical_test_bound, eq_test_initial_reset_prob,
only_classical_equivalence_testing, samples_cex_strategy, output_dir,
save_files_for_each_round, debug)

Expand All @@ -376,14 +390,21 @@ def learn_mdp_and_strategy_from_sul(sul, input_alphabet, prism_model_path, prism
ltl_prop_path, automaton_type='smm', n_c=20, n_resample=1000, min_rounds=20,
max_rounds=240, strategy='normal', cex_processing='longest_prefix',
stopping_based_on_prop=None, target_unambiguity=0.99, eq_num_steps=2000,
smc_max_exec=5000, smc_statistical_test_bound=0.025,
smc_max_exec=5000, smc_statistical_test_bound=0.025, eq_test_initial_reset_prob=0.25,
only_classical_equivalence_testing=False, samples_cex_strategy=None,
output_dir='results', save_files_for_each_round=False, debug=False):
logging.info(f'min_rounds: {min_rounds}')
logging.info(f'max_rounds: {max_rounds}')
logging.info(f'smc_statistical_test_bound: {smc_statistical_test_bound}')
logging.info(f'eq_test_initial_reset_prob: {eq_test_initial_reset_prob}')

eq_oracle = ProbBBReachOracle(prism_model_path, prism_adv_path, prism_prop_path, ltl_prop_path, input_alphabet,
sul=sul, smc_max_exec=smc_max_exec, statistical_test_bound=smc_statistical_test_bound,
only_classical_equivalence_testing=only_classical_equivalence_testing,
num_steps=eq_num_steps, reset_prob=0.25, reset_after_cex=True, output_dir=output_dir,
save_files_for_each_round=save_files_for_each_round, debug=debug)
num_steps=eq_num_steps, initial_reset_prob=eq_test_initial_reset_prob,
reset_after_cex=True,
output_dir=output_dir, save_files_for_each_round=save_files_for_each_round,
debug=debug)
# EQOracleChain
print_level = 2
if debug:
Expand All @@ -397,12 +418,15 @@ def learn_mdp_and_strategy_from_sul(sul, input_alphabet, prism_model_path, prism

learned_strategy = eq_oracle.learned_strategy

sb = StrategyBridge(prism_adv_path, eq_oracle.exportstates_path, eq_oracle.exporttrans_path,
eq_oracle.exportlabels_path)
smc: StatisticalModelChecker = StatisticalModelChecker(sul, sb, ltl_prop_path, 0, None, num_exec=5000,
returnCEX=False)
smc.run()
logging.info(
f'SUT value by final SMC with {smc.num_exec} executions: {smc.exec_count_satisfication / smc.num_exec}')
if learned_strategy is None:
logging.info('No strategy is learned')
else:
sb = StrategyBridge(prism_adv_path, eq_oracle.exportstates_path, eq_oracle.exporttrans_path,
eq_oracle.exportlabels_path)
smc: StatisticalModelChecker = StatisticalModelChecker(sul, sb, ltl_prop_path, 0, None, num_exec=5000,
returnCEX=False)
smc.run()
logging.info(
f'SUT value by final SMC with {smc.num_exec} executions: {smc.exec_count_satisfication / smc.num_exec}')

return learned_mdp, learned_strategy

0 comments on commit 50a56c9

Please sign in to comment.