diff --git a/src/halmos/__main__.py b/src/halmos/__main__.py index 6b732e6e..aca2da66 100644 --- a/src/halmos/__main__.py +++ b/src/halmos/__main__.py @@ -642,17 +642,21 @@ def run( stuck = [] thread_pool = ThreadPoolExecutor(max_workers=args.solver_threads) - result_exs = [] future_models = [] counterexamples = [] unsat_cores = [] - traces = {} + traces: dict[int, str] = {} + exec_cache: dict[int, Exec] = {} def future_callback(future_model): m = future_model.result() models.append(m) model, index, result = m.model, m.index, m.result + + # retrieve cached exec and clear the cache entry + exec = exec_cache.pop(index, None) + if result == unsat: if m.unsat_core: unsat_cores.append(m.unsat_core) @@ -672,20 +676,25 @@ def future_callback(future_model): else: warn_code(COUNTEREXAMPLE_UNKNOWN, f"Counterexample: {result}") - if args.print_failed_states: - print(f"# {idx+1}") - print(result_exs[index]) - if args.verbose >= VERBOSITY_TRACE_COUNTEREXAMPLE: print( - f"Trace #{idx+1}:" + f"Trace #{index + 1}:" if args.verbose == VERBOSITY_TRACE_PATHS else "Trace:" ) print(traces[index], end="") + if args.print_failed_states: + print(f"# {index + 1}") + print(exec) + + # initialize with default value in case we don't enter the loop body + idx = -1 + for idx, ex in enumerate(exs): - result_exs.append(ex) + # cache exec in case we need to print it later + if args.print_failed_states: + exec_cache[idx] = ex if args.verbose >= VERBOSITY_TRACE_PATHS: print(f"Path #{idx+1}:") @@ -725,15 +734,21 @@ def future_callback(future_model): print(ex) normal += 1 + # print post-states + if args.print_states: + print(f"# {idx+1}") + print(ex) + # 0 width is unlimited - if args.width and len(result_exs) >= args.width: + if args.width and idx >= args.width: break + num_execs = idx + 1 timer.create_subtimer("models") - if len(future_models) > 0 and args.verbose >= 1: + if future_models and args.verbose >= 1: print( - f"# of potential paths involving assertion violations: {len(future_models)} / {len(result_exs)} (--solver-threads {args.solver_threads})" + f"# of potential paths involving assertion violations: {len(future_models)} / {num_execs} (--solver-threads {args.solver_threads})" ) # display assertion solving progress @@ -781,7 +796,7 @@ def future_callback(future_model): # print result print( - f"{passfail} {funsig} (paths: {len(result_exs)}, {time_info}, bounds: [{', '.join([str(x) for x in dyn_params])}])" + f"{passfail} {funsig} (paths: {num_execs}, {time_info}, bounds: [{', '.join([str(x) for x in dyn_params])}])" ) for idx, _, err in stuck: @@ -797,12 +812,6 @@ def future_callback(future_model): ) debug("\n".join(jumpid_str(x) for x in logs.bounded_loops)) - # print post-states - if args.print_states: - for idx, ex in enumerate(result_exs): - print(f"# {idx+1} / {len(result_exs)}") - print(ex) - # log steps if args.log: with open(args.log, "w") as json_file: @@ -817,7 +826,7 @@ def future_callback(future_model): exitcode, len(counterexamples), counterexamples, - (len(result_exs), normal, len(stuck)), + (num_execs, normal, len(stuck)), (timer.elapsed(), timer["paths"].elapsed(), timer["models"].elapsed()), len(logs.bounded_loops), ) @@ -1352,6 +1361,11 @@ def _main(_args=None) -> MainResult: logger.setLevel(logging.DEBUG) logger_unique.setLevel(logging.DEBUG) + if args.trace_memory: + import halmos.memtrace as memtrace + + memtrace.MemTracer.get().start() + # # compile # diff --git a/src/halmos/config.py b/src/halmos/config.py index fd88090e..2d28e70f 100644 --- a/src/halmos/config.py +++ b/src/halmos/config.py @@ -394,6 +394,12 @@ class Config: group=debugging, ) + trace_memory: bool = arg( + help="trace memory allocations and deallocations", + global_default=False, + group=debugging, + ) + ### Build options forge_build_out: str = arg( @@ -787,7 +793,7 @@ def _to_toml_str(value: Any, type) -> str: continue name = field_info.name.replace("_", "-") - if name in ["config", "root", "version"]: + if name in ["config", "root", "version", "trace_memory"]: # skip fields that don't make sense in a config file continue diff --git a/src/halmos/memtrace.py b/src/halmos/memtrace.py new file mode 100644 index 00000000..4fe0ce43 --- /dev/null +++ b/src/halmos/memtrace.py @@ -0,0 +1,181 @@ +import io +import linecache +import threading +import time +import tracemalloc + +from rich.console import Console + +from halmos.logs import debug + +console = Console() + + +def readable_size(num: int | float) -> str: + if num < 1024: + return f"{num}B" + + if num < 1024 * 1024: + return f"{num/1024:.1f}KiB" + + return f"{num/(1024*1024):.1f}MiB" + + +def pretty_size(num: int | float) -> str: + return f"[magenta]{readable_size(num)}[/magenta]" + + +def pretty_count_diff(num: int | float) -> str: + if num > 0: + return f"[red]+{num}[/red]" + elif num < 0: + return f"[green]{num}[/green]" + else: + return "[gray]0[/gray]" + + +def pretty_line(line: str): + return f"[white] {line}[/white]" if line else "" + + +def pretty_frame_info( + frame: tracemalloc.Frame, result_number: int | None = None +) -> str: + result_number_str = ( + f"[grey37]# {result_number+1}:[/grey37] " if result_number is not None else "" + ) + filename_str = f"[grey37]{frame.filename}:[/grey37]" + lineno_str = f"[grey37]{frame.lineno}:[/grey37]" + return f"{result_number_str}{filename_str}{lineno_str}" + + +class MemTracer: + curr_snapshot: tracemalloc.Snapshot | None = None + prev_snapshot: tracemalloc.Snapshot | None = None + running: bool = False + + _instance = None + _lock = threading.Lock() + + def __init__(self): + if MemTracer._instance is not None: + raise RuntimeError("Use MemTracer.get() to access the singleton instance.") + self.curr_snapshot = None + self.prev_snapshot = None + self.running = False + + @classmethod + def get(cls): + if cls._instance is None: + with cls._lock: + if cls._instance is None: + cls._instance = cls() + return cls._instance + + def take_snapshot(self): + debug("memtracer: taking snapshot") + self.prev_snapshot = self.curr_snapshot + self.curr_snapshot = tracemalloc.take_snapshot() + self.display_stats() + + def display_stats(self): + """Display statistics about the current memory snapshot.""" + if not self.running: + return + + if self.curr_snapshot is None: + debug("memtracer: no current snapshot") + return + + out = io.StringIO() + + # Show top memory consumers by line + out.write("[cyan][ Top memory consumers ][/cyan]\n") + stats = self.curr_snapshot.statistics("lineno") + for i, stat in enumerate(stats[:10]): + frame = stat.traceback[0] + line = linecache.getline(frame.filename, frame.lineno).strip() + out.write(f"{pretty_frame_info(frame, i)} " f"{pretty_size(stat.size)}\n") + out.write(f"{pretty_line(line)}\n") + out.write("\n") + + # Get total memory usage + total = sum(stat.size for stat in self.curr_snapshot.statistics("filename")) + out.write(f"Total memory used in snapshot: {pretty_size(total)}\n\n") + + console.print(out.getvalue()) + + def start(self, interval_seconds=60): + """Start tracking memory usage at the specified interval.""" + if not tracemalloc.is_tracing(): + nframes = 1 + tracemalloc.start(nframes) + self.running = True + + self.take_snapshot() + threading.Thread( + target=self._run, args=(interval_seconds,), daemon=True + ).start() + + def stop(self): + """Stop the memory tracer.""" + self.running = False + + def _run(self, interval_seconds): + """Run the tracer periodically.""" + while self.running: + time.sleep(interval_seconds) + self.take_snapshot() + self._display_differences() + + def _display_differences(self): + """Display top memory differences between snapshots.""" + + if not self.running: + return + + if self.prev_snapshot is None or self.curr_snapshot is None: + debug("memtracer: no snapshots to compare") + return + + out = io.StringIO() + + top_stats = self.curr_snapshot.compare_to( + self.prev_snapshot, "lineno", cumulative=True + ) + out.write("[cyan][ Top differences ][/cyan]\n") + for i, stat in enumerate(top_stats[:10]): + frame = stat.traceback[0] + line = linecache.getline(frame.filename, frame.lineno).strip() + out.write( + f"{pretty_frame_info(frame, i)} " + f"{pretty_size(stat.size_diff)} " + f"[{pretty_count_diff(stat.count_diff)}]\n" + ) + out.write(f"{pretty_line(line)}\n") + + total_diff = sum(stat.size_diff for stat in top_stats) + out.write(f"Total size difference: {pretty_size(total_diff)}\n") + + console.print(out.getvalue()) + + +def main(): + tracer = MemTracer.get() + tracer.start(interval_seconds=2) + + # Simulate some workload + import random + + memory_hog = [] + try: + while True: + memory_hog.append([random.random() for _ in range(1000)]) + time.sleep(0.1) + except KeyboardInterrupt: + # Stop the tracer on exit + tracer.stop() + + +if __name__ == "__main__": + main()