Skip to content

Commit

Permalink
assign racks to different threads for breadth-first sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
andy-k committed Apr 21, 2024
1 parent 295e6bf commit 03ef0f0
Showing 1 changed file with 66 additions and 43 deletions.
109 changes: 66 additions & 43 deletions src/main_leave.rs
Original file line number Diff line number Diff line change
Expand Up @@ -566,7 +566,10 @@ fn generate_autoplay_logs<const WRITE_LOGS: bool, const SUMMARIZE: bool, const B
} else {
Vec::new()
};
let mut thread_skip = 0;
let mut thread_lo = 0;
let mut thread_hi = 0;
let big = 1u64 << 62;
let big_plus_one = big + 1;
if BREADTH {
// calculate how many possible racks there are.
// there is a faster formula for this, but iterating is "only" ~60ms.
Expand All @@ -583,43 +586,58 @@ fn generate_autoplay_logs<const WRITE_LOGS: bool, const SUMMARIZE: bool, const B
},
0,
);
// temporarily prefill the first (thread_id / num_thread) racks to help the different threads cooperate.
// avoid overflow by dividing first.
thread_skip = num_racks / num_threads; // floor division.
thread_skip = (thread_skip * thread_id)
+ (num_racks - thread_skip * num_threads) * thread_id / num_threads;
if thread_skip > 0 {
// since the iteration order is deterministic, this can be undone.
generate_exchanges_abortable(
&mut ExchangeAbortableEnv {
found_exchange_move: |rack_bytes: &[u8]| -> bool {
if thread_full_rack_map.len() >= thread_skip {
return false;
}
thread_lo = num_racks / num_threads; // floor division.
thread_hi = (thread_lo * (thread_id + 1))
+ (num_racks - thread_lo * num_threads) * (thread_id + 1) / num_threads;
thread_lo = (thread_lo * thread_id)
+ (num_racks - thread_lo * num_threads) * thread_id / num_threads;
// temporarily prefill racks as follows. (these are exclusive ranges.)
// racks 0..thread_lo: big + 1.
// racks thread_lo..thread_hi: 0.
// racks thread_hi..num_racks: big.
// this thread would focud on filling in its assigned region first.
// if it's all unavailable, it would pick the next rack (wraps around).
let mut rack_idx = 0;
generate_exchanges(
&mut ExchangeEnv {
found_exchange_move: |rack_bytes: &[u8]| {
if rack_idx < thread_lo {
thread_full_rack_map.insert(
rack_bytes.into(),
Cumulate {
equity: 0.0,
count: 1,
count: big_plus_one,
},
);
true
},
rack_tally: &mut alphabet_freqs,
min_len: game_config.rack_size(),
max_len: game_config.rack_size(),
exchange_buffer: &mut exchange_buffer,
} else if rack_idx >= thread_hi {
thread_full_rack_map.insert(
rack_bytes.into(),
Cumulate {
equity: 0.0,
count: big,
},
);
}
rack_idx += 1;
},
0,
);
rack_tally: &mut alphabet_freqs,
min_len: game_config.rack_size(),
max_len: game_config.rack_size(),
exchange_buffer: &mut exchange_buffer,
},
0,
);
let num_thread_hi = num_racks - thread_hi;
num_racks = thread_hi - thread_lo;
if thread_lo > 0 {
thread_sample_count_map.insert(big_plus_one, thread_lo);
}
assert_eq!(thread_skip, thread_full_rack_map.len());
num_racks -= thread_skip;
if num_racks > 0 {
thread_sample_count_map.insert(0, num_racks);
}
if thread_skip > 0 {
thread_sample_count_map.insert(1, thread_skip);
if num_thread_hi > 0 {
thread_sample_count_map.insert(big, num_thread_hi);
}
}
loop {
Expand Down Expand Up @@ -974,31 +992,36 @@ fn generate_autoplay_logs<const WRITE_LOGS: bool, const SUMMARIZE: bool, const B

if BREADTH {
// undo the prefill to get the real stats.
if thread_skip > 0 {
generate_exchanges_abortable(
&mut ExchangeAbortableEnv {
found_exchange_move: |rack_bytes: &[u8]| -> bool {
if thread_skip == 0 {
return false;
let mut rack_idx = 0;
generate_exchanges(
&mut ExchangeEnv {
found_exchange_move: |rack_bytes: &[u8]| {
if rack_idx < thread_lo {
if let std::collections::hash_map::Entry::Occupied(mut entry) =
thread_full_rack_map.entry(rack_bytes.into())
{
entry.get_mut().count -= big_plus_one;
} else {
unreachable!();
}
} else if rack_idx >= thread_hi {
if let std::collections::hash_map::Entry::Occupied(mut entry) =
thread_full_rack_map.entry(rack_bytes.into())
{
entry.get_mut().count -= 1;
thread_skip -= 1;
entry.get_mut().count -= big;
} else {
unreachable!();
}
true
},
rack_tally: &mut alphabet_freqs,
min_len: game_config.rack_size(),
max_len: game_config.rack_size(),
exchange_buffer: &mut exchange_buffer,
}
rack_idx += 1;
},
0,
);
}
rack_tally: &mut alphabet_freqs,
min_len: game_config.rack_size(),
max_len: game_config.rack_size(),
exchange_buffer: &mut exchange_buffer,
},
0,
);
}

let batched_csv_log_buf = batched_csv_log.into_inner().unwrap();
Expand Down

0 comments on commit 03ef0f0

Please sign in to comment.