From bd65938a713b2c79856416deb4991098c13f9c4f Mon Sep 17 00:00:00 2001
From: Raj Prateek Kosaraju
Date: Thu, 20 May 2021 15:32:44 -0700
Subject: [PATCH] Sweep code for studying model population stats (1 of 2) (#143)

Summary:
This is a *major update* that introduces powerful new functionality to pycls. The pycls codebase now provides support for studying *design spaces* and, more generally, *population statistics* of models, as introduced in [On Network Design Spaces for Visual Recognition](https://arxiv.org/abs/1905.13214) and [Designing Network Design Spaces](https://arxiv.org/abs/2003.13678). The idea is that instead of running a single pycls job (e.g., testing a specific model configuration), one can study the behavior of an entire population of models. This enables succinct yet powerful experimental designs and elevates the study of individual model behavior to the study of the behavior of model populations. Please see [`SWEEP_INFO`](docs/SWEEP_INFO.md) for details.

This is commit 1 of 2 for the sweep code. It is focused on the sweep config, setting up the sweep, and launching it.

Pull Request resolved: https://github.com/facebookresearch/pycls/pull/143

Reviewed By: pdollar

Differential Revision: D28580825

Pulled By: rajprateek

fbshipit-source-id: 9221f0a9b3651642d2c6acd87befacd6521825cc

Co-authored-by: Raj Prateek Kosaraju
Co-authored-by: Piotr Dollar <699682+pdollar@users.noreply.github.com>
---
 pycls/sweep/__init__.py   |   0
 pycls/sweep/config.py     | 246 ++++++++++++++++++++++++++++++++++++++
 pycls/sweep/random.py     | 120 +++++++++++++++++++
 pycls/sweep/samplers.py   | 111 +++++++++++++++++
 tools/sweep_collect.py    |  89 ++++++++++++++
 tools/sweep_launch.py     | 108 +++++++++++++++++
 tools/sweep_launch_job.py | 122 +++++++++++++++++++
 tools/sweep_setup.py      | 126 +++++++++++++++++++
 8 files changed, 922 insertions(+)
 create mode 100644 pycls/sweep/__init__.py
 create mode 100644 pycls/sweep/config.py
 create mode 100644 pycls/sweep/random.py
 create mode 100644 pycls/sweep/samplers.py
 create mode 100644 tools/sweep_collect.py
 create mode 100644 tools/sweep_launch.py
 create mode 100644 tools/sweep_launch_job.py
 create mode 100644 tools/sweep_setup.py

diff --git a/pycls/sweep/__init__.py b/pycls/sweep/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/pycls/sweep/config.py b/pycls/sweep/config.py
new file mode 100644
index 0000000..7ab884a
--- /dev/null
+++ b/pycls/sweep/config.py
@@ -0,0 +1,246 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
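+
+# Illustrative sweep cfg in YAML form (the sweep name and all values below are
+# hypothetical, shown only to make the options defined in this file concrete;
+# see docs/SWEEP_INFO.md for real examples):
+#
+#   NAME: example_regnet_sweep
+#   SETUP:
+#     NUM_CONFIGS: 32
+#     SAMPLERS:
+#       REGNET:
+#         TYPE: regnet_sampler
+#         DEPTH: [12, 28]
+#         GROUP_W: [8, 64]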
+ +"""Configuration file (powered by YACS).""" + +import argparse +import getpass +import multiprocessing +import os +import sys + +from pycls.core.config import cfg +from pycls.sweep.samplers import validate_sampler +from yacs.config import CfgNode as CfgNode + + +# Example usage: from sweep.config import sweep_cfg +sweep_cfg = _C = CfgNode() + + +# ------------------------------ General sweep options ------------------------------- # + +# Sweeps root directory where all sweep output subdirectories will be placed +_C.ROOT_DIR = "/checkpoint/{}/sweeps/".format(getpass.getuser()) + +# Sweep name must be unique per sweep and defines the output subdirectory +_C.NAME = "" + +# Optional description of a sweep useful to keep track of sweeps +_C.DESC = "" + +# Number of processes to use for various sweep steps except for running jobs +_C.NUM_PROC = multiprocessing.cpu_count() + +# Automatically overwritten to the file from which the sweep_cfg is loaded +_C.SWEEP_CFG_FILE = "" + + +# ------------------------------- Sweep setup options -------------------------------- # +_C.SETUP = CfgNode() + +# Max number of unique job configs to generate +_C.SETUP.NUM_CONFIGS = 0 + +# Max number of attempts for generating NUM_CONFIGS valid configs +_C.SETUP.NUM_SAMPLES = 1000000 + +# Specifies the chunk size to use per process while sampling configs +_C.SETUP.CHUNK_SIZE = 5000 + +# Random seed for generating job configs +_C.SETUP.RNG_SEED = 0 + +# Base config for all jobs, any valid config option in core.config is valid here +_C.SETUP.BASE_CFG = cfg.clone() + +# Samplers to use for generating job configs (see SAMPLERS defined toward end of file) +# SETUP.SAMPLERS should consists of a dictionary of SAMPLERS +# Each dict key should be a valid parameter in the BASE_CFG (e.g. "MODEL.DEPTH") +# Each dict val should be a valid SAMPLER that defines how to sample (e.g. INT_SAMPLER) +# See the example sweep configs for more usage information +_C.SETUP.SAMPLERS = CfgNode(new_allowed=True) + +# Constraints on generated configs +_C.SETUP.CONSTRAINTS = CfgNode() + +# Complexity constraints CX on models specified as a [LOW, HIGH] range, e.g. [0, 1.0e+6] +# If LOW == HIGH == 0 for a given complexity constraint that constraint is not applied +# For RegNets, if flops= q else f + return f + + +def normal(low, high, q): + """Samples values from a clipped normal (Gaussian) distribution quantized to q.""" + # mu/sigma are computed from low/high such that ~99.7% of samples are in range + f, mu, sigma = np.inf, (low + high) / 2, (high - low) / 6 + while not low <= f <= high: + f = np.random.normal(mu, sigma) + return quantize(f, q, np.round) + + +def log_normal(low, high, q): + """Samples values from a clipped log-normal distribution quantized to q.""" + # mu/sigma are computed from low/high such that ~99.7% of samples are in range + log_low, log_high = np.log(low), np.log(high) + f, mu, sigma = np.inf, (log_low + log_high) / 2, (log_high - log_low) / 6 + while not low <= f <= high: + f = np.random.lognormal(mu, sigma) + return quantize(f, q, np.round) + + +rand_types = { + "uniform": uniform, + "log_uniform": log_uniform, + "power2_uniform": power2_uniform, + "power2_or_log_uniform": power2_or_log_uniform, + "normal": normal, + "log_normal": log_normal, +} + + +def validate_rand(err_str, rand_type, low, high, q): + """Validate parameters to random number generators.""" + err_msg = "{}: {}(low={}, high={}, q={}) is invalid." 
+    err_msg = err_msg.format(err_str, rand_type, low, high, q)
+    low_q = Decimal(str(low)) % Decimal(str(q)) == 0
+    high_q = Decimal(str(high)) % Decimal(str(q)) == 0
+    assert type(q) == type(low) == type(high), err_msg
+    assert rand_type in rand_types, err_msg
+    assert q > 0 and low <= high, err_msg
+    assert low > 0 or rand_type in ["uniform", "normal"], err_msg
+    assert low_q and high_q or rand_type == "power2_or_log_uniform", err_msg
+    if rand_type in ["power2_uniform", "power2_or_log_uniform"]:
+        assert all(np.log2(v).is_integer() for v in [low, high, q]), err_msg
+
+
+def plot_rand_cdf(rand_type, low, high, q, n=10000):
+    """Visualizes the CDF of a rand_type; the resulting CDF should be linear (or log-linear)."""
+    validate_rand("plot_rand_cdf", rand_type, low, high, q)
+    samples = [rand_types[rand_type](low, high, q) for _ in range(n)]
+    unique = list(np.unique(samples))
+    assert min(unique) >= low and max(unique) <= high, "Sampled value out of range."
+    cdf = np.cumsum(np.histogram(samples, unique + [np.inf])[0]) / len(samples)
+    plot_fun = plt.plot if rand_type in ["uniform", "normal"] else plt.semilogx
+    plot_fun(unique, cdf, "o-", [low, low], [0, 1], "-k", [high, high], [0, 1], "-k")
+    plot_fun([low, high], [cdf[0], cdf[-1]]) if "normal" not in rand_type else ()
+    plt.title("{}(low={}, high={}, q={})".format(rand_type, low, high, q))
+    plt.show()
+
+
+def plot_rand_cdfs():
+    """Visualizes CDFs of selected distributions, for visualization/debugging only."""
+    plot_rand_cdf("uniform", -0.5, 0.5, 0.1)
+    plot_rand_cdf("power2_uniform", 2, 512, 1)
+    plot_rand_cdf("power2_uniform", 0.25, 8.0, 0.25)
+    plot_rand_cdf("log_uniform", 1, 32, 1)
+    plot_rand_cdf("log_uniform", 0.5, 16.0, 0.5)
+    plot_rand_cdf("power2_or_log_uniform", 1.0, 16.0, 1.0)
+    plot_rand_cdf("power2_or_log_uniform", 0.25, 4.0, 4.0)
+    plot_rand_cdf("power2_or_log_uniform", 1, 128, 4)
diff --git a/pycls/sweep/samplers.py b/pycls/sweep/samplers.py
new file mode 100644
index 0000000..af00798
--- /dev/null
+++ b/pycls/sweep/samplers.py
@@ -0,0 +1,111 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
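+
+# Illustrative SETUP.SAMPLERS entry in YAML form that the samplers below consume
+# (the parameter name and values are hypothetical): each key is a parameter in
+# the BASE_CFG and each value describes how to sample it, e.g.
+#
+#   OPTIM.BASE_LR:
+#     TYPE: float_sampler
+#     RAND_TYPE: log_uniform
+#     RANGE: [0.25, 4.0]
+#     QUANTIZE: 0.25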
+ +"""Sweep sampling functions.""" + +import numpy as np +import pycls.core.builders as builders +import pycls.core.net as net +import pycls.models.regnet as regnet +import pycls.sweep.random as rand +from pycls.core.config import cfg +from pycls.sweep.random import validate_rand + + +def scalar_sampler(sampler): + """Sampler for scalars in RANGE quantized to QUANTIZE.""" + low, high = sampler.RANGE[0], sampler.RANGE[1] + rand_fun, q = rand.rand_types[sampler.RAND_TYPE], sampler.QUANTIZE + return rand_fun(low, high, q) + + +def value_sampler(sampler): + """Sampler for uniform sampling from a list of values.""" + rand_index = np.random.randint(len(sampler.VALUES)) + return sampler.VALUES[rand_index] + + +def list_sampler(sampler): + """Sampler for a list of n items sampled independently by the item_sampler.""" + item_sampler, n = sampler.ITEM_SAMPLER, sampler.LENGTH + sampler_function = sampler_types[item_sampler.TYPE] + return [sampler_function(item_sampler) for _ in range(n)] + + +def regnet_sampler(sampler): + """Sampler for main RegNet parameters.""" + d = rand.uniform(*sampler.DEPTH, 1) + w0 = rand.log_uniform(*sampler.W0, 8) + wa = rand.log_uniform(*sampler.WA, 0.1) + wm = rand.log_uniform(*sampler.WM, 0.001) + gw = rand.power2_or_log_uniform(*sampler.GROUP_W, 8) + bm = rand.power2_uniform(*sampler.BOT_MUL, 1 / 128) + params = ["DEPTH", d, "W0", w0, "WA", wa, "WM", wm, "GROUP_W", gw, "BOT_MUL", bm] + return ["REGNET." + p if i % 2 == 0 else p for i, p in enumerate(params)] + + +sampler_types = { + "float_sampler": scalar_sampler, + "int_sampler": scalar_sampler, + "value_sampler": value_sampler, + "list_sampler": list_sampler, + "regnet_sampler": regnet_sampler, +} + + +def validate_sampler(param, sampler): + """Performs various checks on sampler to see if it is valid.""" + if sampler.TYPE in ["int_sampler", "float_sampler"]: + validate_rand(param, sampler.RAND_TYPE, *sampler.RANGE, sampler.QUANTIZE) + elif sampler.TYPE == "regnet_sampler": + assert param == "REGNET", "regnet_sampler can only be used for REGNET" + validate_rand("REGNET.DEPTH", "uniform", *sampler.DEPTH, 1) + validate_rand("REGNET.W0", "log_uniform", *sampler.W0, 8) + validate_rand("REGNET.WA", "log_uniform", *sampler.WA, 0.1) + validate_rand("REGNET.WM", "log_uniform", *sampler.WM, 0.001) + validate_rand("REGNET.GROUP_W", "power2_or_log_uniform", *sampler.GROUP_W, 8) + validate_rand("REGNET.BOT_MUL", "power2_uniform", *sampler.BOT_MUL, 1 / 128) + + +def is_composite_sampler(sampler_type): + """Composite samplers return a [key, val, ...] list as opposed to just a val.""" + composite_samplers = ["regnet_sampler"] + return sampler_type in composite_samplers + + +def sample_parameters(samplers): + """Samples params [key, val, ...] 
+    params = []
+    for param, sampler in samplers.items():
+        val = sampler_types[sampler.TYPE](sampler)
+        is_composite = is_composite_sampler(sampler.TYPE)
+        params.extend(val if is_composite else [param, val])
+    return params
+
+
+def check_regnet_constraints(constraints):
+    """Checks RegNet-specific constraints."""
+    if cfg.MODEL.TYPE == "regnet":
+        wa, w0, wm, d = cfg.REGNET.WA, cfg.REGNET.W0, cfg.REGNET.WM, cfg.REGNET.DEPTH
+        _, _, num_s, max_s, _, _ = regnet.generate_regnet(wa, w0, wm, d, 8)
+        num_stages = constraints.REGNET.NUM_STAGES
+        if num_s != max_s or not num_stages[0] <= num_s <= num_stages[1]:
+            return False
+    return True
+
+
+def check_complexity_constraints(constraints):
+    """Checks complexity constraints."""
+    cx, valid = None, True
+    for p, v in constraints.CX.items():
+        p, min_v, max_v = p.lower(), v[0], v[1]
+        if min_v != 0 or max_v != 0:
+            cx = cx if cx else net.complexity(builders.get_model())
+            min_v = cx[p] if min_v == 0 else min_v
+            max_v = cx[p] if max_v == 0 else max_v
+            valid = valid and (min_v <= cx[p] <= max_v)
+    return valid
diff --git a/tools/sweep_collect.py b/tools/sweep_collect.py
new file mode 100644
index 0000000..b649a84
--- /dev/null
+++ b/tools/sweep_collect.py
@@ -0,0 +1,89 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+"""Collect results of a sweep."""
+
+import functools
+import json
+import multiprocessing
+import os
+
+import pycls.core.checkpoint as cp
+import pycls.core.logging as logging
+import pycls.sweep.config as sweep_config
+from pycls.sweep.config import sweep_cfg
+
+
+# Skip over these data types as they make sweep logs too large
+_DATA_TYPES_TO_SKIP = ["train_iter", "test_iter"]
+
+
+def load_data(log_file):
+    """Loads and sorts the log data for a single job."""
+    data = logging.load_log_data(log_file, _DATA_TYPES_TO_SKIP)
+    data = logging.sort_log_data(data)
+    err_file = log_file.replace("stdout.log", "stderr.log")
+    data["log_file"] = log_file
+    data["err_file"] = err_file
+    with open(err_file, "r") as f:
+        data["err"] = f.read()
+    return data
+
+
+def sweep_collect():
+    """Collects results of a sweep."""
+    # Get cfg and log files
+    sweep_dir = os.path.join(sweep_cfg.ROOT_DIR, sweep_cfg.NAME)
+    print("Collecting jobs for {:s}...".format(sweep_dir))
".format(sweep_dir)) + cfgs_dir = os.path.join(sweep_dir, "cfgs") + logs_dir = os.path.join(sweep_dir, "logs") + assert os.path.exists(cfgs_dir), "Cfgs dir {} not found".format(cfgs_dir) + assert os.path.exists(logs_dir), "Logs dir {} not found".format(logs_dir) + cfg_files = [c for c in os.listdir(cfgs_dir) if c.endswith(".yaml")] + log_files = logging.get_log_files(logs_dir)[0] + # Create worker pool for collecting jobs + process_pool = multiprocessing.Pool(sweep_cfg.NUM_PROC) + # Load the sweep and keep only non-empty data + print("Collecting jobs...") + sweep = list(process_pool.map(load_data, log_files)) + # Print basic stats for sweep status + key = "test_epoch" + epoch_ind = [d[key]["epoch_ind"][-1] if key in d else 0 for d in sweep] + epoch_max = [d[key]["epoch_max"][-1] if key in d else 1 for d in sweep] + epoch = ["{}/{}".format(i, m) for i, m in zip(epoch_ind, epoch_max)] + epoch = [e.ljust(len(max(epoch, key=len))) for e in epoch] + job_done = sum(i == m for i, m in zip(epoch_ind, epoch_max)) + for d, e, i, m in zip(sweep, epoch, epoch_ind, epoch_max): + out_str = " {} [{:3d}%] [{:}]" + (" [stderr]" if d["err"] else "") + print(out_str.format(d["log_file"], int(i / m * 100), e)) + jobs_start = "jobs_started={}/{}".format(len(sweep), len(cfg_files)) + jobs_done = "jobs_done={}/{}".format(job_done, len(cfg_files)) + ep_done = "epochs_done={}/{}".format(sum(epoch_ind), sum(epoch_max)) + print("Status: {}, {}, {}".format(jobs_start, jobs_done, ep_done)) + # Save the sweep data + sweep_file = os.path.join(sweep_dir, "sweep.json") + print("Writing sweep data to: {}".format(sweep_file)) + with open(sweep_file, "w") as f: + json.dump(sweep, f, sort_keys=True) + # Clean up checkpoints after saving sweep data, if needed + keep = sweep_cfg.COLLECT.CHECKPOINTS_KEEP + cp_dirs = [f.replace("stdout.log", "checkpoints/") for f in log_files] + delete_cps = functools.partial(cp.delete_checkpoints, keep=keep) + num_cleaned = sum(process_pool.map(delete_cps, cp_dirs)) + print("Deleted {} total checkpoints".format(num_cleaned)) + + +def main(): + desc = "Collect results of a sweep." + sweep_config.load_cfg_fom_args(desc) + sweep_cfg.freeze() + sweep_collect() + + +if __name__ == "__main__": + main() diff --git a/tools/sweep_launch.py b/tools/sweep_launch.py new file mode 100644 index 0000000..2ea2d7e --- /dev/null +++ b/tools/sweep_launch.py @@ -0,0 +1,108 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +"""Launch sweep on a SLURM managed cluster.""" + +import os + +import pycls.sweep.config as sweep_config +from pycls.sweep.config import sweep_cfg + + +_SBATCH_CMD = ( + "sbatch" + " --job-name={name}" + " --partition={partition}" + " --gpus={num_gpus}" + " --constraint={gpu_type}" + " --mem={mem}GB" + " --cpus-per-task={cpus}" + " --array=0-{last_job}%{parallel_jobs}" + " --output={sweep_dir}/logs/sbatch/%A_%a.out" + " --error={sweep_dir}/logs/sbatch/%A_%a.out" + " --time={time_limit}" + ' --comment="{comment}"' + " --signal=B:USR1@300" + " --nodes=1" + " --open-mode=append" + " --ntasks-per-node=1" + " {current_dir}/sweep_launch_job.py" + " --conda-env {conda_env}" + " --script-path {script_path}" + " --cfgs-dir {cfgs_dir}" + " --pycls-dir {pycls_dir}" + " --logs-dir {logs_dir}" + " --max-retry {max_retry}" +) + + +def sweep_launch(): + """Launch sweep on a SLURM managed cluster.""" + launch_cfg = sweep_cfg.LAUNCH + # Get and check directory and script locations + current_dir = os.path.dirname(os.path.abspath(__file__)) + sweep_dir = os.path.abspath(os.path.join(sweep_cfg.ROOT_DIR, sweep_cfg.NAME)) + cfgs_dir = os.path.join(sweep_dir, "cfgs") + logs_dir = os.path.join(sweep_dir, "logs") + sbatch_dir = os.path.join(logs_dir, "sbatch") + script_path = os.path.abspath(launch_cfg.SCRIPT) + assert os.path.exists(sweep_dir), "Sweep dir {} invalid".format(sweep_dir) + assert os.path.exists(script_path), "Script path {} invalid".format(script_path) + n_cfgs = len([c for c in os.listdir(cfgs_dir) if c.endswith(".yaml")]) + # Replace path to be relative to copy of pycls + pycls_copy_dir = os.path.join(sweep_dir, "pycls") + pycls_dir = os.path.abspath(os.path.join(current_dir, "..")) + script_path = script_path.replace(pycls_dir, pycls_copy_dir) + current_dir = current_dir.replace(pycls_dir, pycls_copy_dir) + # Prepare command to copy pycls to sweep_dir/pycls + cmd_to_copy_pycls = "cp -R {}/ {}".format(pycls_dir, pycls_copy_dir) + print("Cmd to copy pycls:", cmd_to_copy_pycls) + # Prepare launch command + cmd_to_launch_sweep = _SBATCH_CMD.format( + name=sweep_cfg.NAME, + partition=launch_cfg.PARTITION, + num_gpus=launch_cfg.NUM_GPUS, + gpu_type=launch_cfg.GPU_TYPE, + mem=launch_cfg.MEM_PER_GPU * launch_cfg.NUM_GPUS, + cpus=launch_cfg.CPUS_PER_GPU * launch_cfg.NUM_GPUS, + last_job=n_cfgs - 1, + parallel_jobs=launch_cfg.PARALLEL_JOBS, + time_limit=launch_cfg.TIME_LIMIT, + comment=launch_cfg.COMMENT, + sweep_dir=sweep_dir, + current_dir=current_dir, + conda_env=launch_cfg.CONDA_ENV, + script_path=script_path, + cfgs_dir=cfgs_dir, + pycls_dir=pycls_copy_dir, + logs_dir=logs_dir, + max_retry=launch_cfg.MAX_RETRY, + ) + print("Cmd to launch sweep:", cmd_to_launch_sweep.replace(" ", "\n "), sep="\n\n") + # Prompt user to resume or launch sweep + if os.path.exists(sbatch_dir): + print("\nSweep exists! Relaunch ONLY if no jobs are running!") + print("\nRelaunch sweep? [relaunch/n]") + if input().lower() == "relaunch": + os.system(cmd_to_launch_sweep) + else: + print("\nLaunch sweep? [y/n]") + if input().lower() == "y": + os.makedirs(sbatch_dir, exist_ok=False) + os.system(cmd_to_copy_pycls) + os.system(cmd_to_launch_sweep) + + +def main(): + desc = "Launch a sweep on the cluster." 
+    sweep_config.load_cfg_fom_args(desc)
+    sweep_cfg.freeze()
+    sweep_launch()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tools/sweep_launch_job.py b/tools/sweep_launch_job.py
new file mode 100644
index 0000000..313ba5e
--- /dev/null
+++ b/tools/sweep_launch_job.py
@@ -0,0 +1,122 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Launch a job on a SLURM-managed cluster. Should only be called from sweep_launch.py."""
+
+import argparse
+import json
+import os
+import signal
+import subprocess
+import sys
+from datetime import datetime
+
+
+def prt(*args, **kwargs):
+    """Wrapper for print that prepends a timestamp and flushes output."""
+    print("[{}]".format(str(datetime.now())), *args, flush=True, **kwargs)
+
+
+def run_os_cmd(cmd):
+    """Runs a command in a bash shell in the foreground (each call spawns its own shell)."""
+    os.system('bash -c "{}"'.format(cmd))
+
+
+def requeue_job():
+    job_id = os.environ["SLURM_ARRAY_JOB_ID"]
+    task_id = os.environ["SLURM_ARRAY_TASK_ID"]
+    cmd_to_req = "scontrol requeue {}_{}".format(job_id, task_id)
+    prt("Requeuing job using cmd: {}".format(cmd_to_req))
+    os.system(cmd_to_req)
+    prt("Requeued job {}_{}. Exiting.\n\n".format(job_id, task_id))
+    sys.exit(0)
+
+
+def sigusr1_handler(signum, _):
+    """Handles SIGUSR1 (sent shortly before a job is killed) by requeuing the job."""
+    prt("Caught SIGUSR1 with code {}".format(signum))
+    requeue_job()
+
+
+def sigterm_handler(signum, _):
+    """Handles SIGTERM (sent when a job is preempted) by ignoring it."""
+    prt("Caught SIGTERM with code {}".format(signum))
+    prt("Bypassing SIGTERM")
+
+
+def main():
+    # Parse arguments
+    desc = "Launch a job on a SLURM cluster. Should only be called from sweep_launch.py"
+    parser = argparse.ArgumentParser(description=desc)
+    parser.add_argument("--conda-env", required=True)
+    parser.add_argument("--script-path", required=True)
+    parser.add_argument("--cfgs-dir", required=True)
+    parser.add_argument("--pycls-dir", required=True)
+    parser.add_argument("--logs-dir", required=True)
+    parser.add_argument("--max-retry", required=True, type=int)
+    args = parser.parse_args()
+    prt("Called with args: {}".format(args))
+    # Attach signal handlers for SIGUSR1 and SIGTERM
+    signal.signal(signal.SIGUSR1, sigusr1_handler)
+    signal.signal(signal.SIGTERM, sigterm_handler)
+    # Print info about run
+    job_id = os.environ["SLURM_ARRAY_JOB_ID"]
+    task_id = os.environ["SLURM_ARRAY_TASK_ID"]
+    prt("Job array master job ID: {}".format(job_id))
+    prt("Job array task ID (index): {}".format(task_id))
+    prt("Running job on: {}".format(str(os.uname())))
+    # Load what we need
+    run_os_cmd("module purge")
+    run_os_cmd("module load anaconda3")
+    run_os_cmd("source deactivate")
+    run_os_cmd("source activate {}".format(args.conda_env))
+    # Get cfg_file to use
+    cfg_files = sorted(f for f in os.listdir(args.cfgs_dir) if f.endswith(".yaml"))
+    cfg_file = os.path.join(args.cfgs_dir, cfg_files[int(task_id)])
+    prt("Using cfg_file: {}".format(cfg_file))
+    # Create out_dir
+    out_dir = os.path.join(args.logs_dir, "{:06}".format(int(task_id)))
+    os.makedirs(out_dir, exist_ok=True)
+    prt("Using out_dir: {}".format(out_dir))
+    # Create slurm_file with SLURM info
+    slurm_file = os.path.join(out_dir, "SLURM.txt")
+    with open(slurm_file, "a") as f:
+        f.write("SLURM env variables for the job writing to this directory:\n")
+        slurm_info = {k: os.environ[k] for k in os.environ if k.startswith("SLURM_")}
+        f.write(json.dumps(slurm_info, indent=4))
+    prt("Dumped SLURM job info to {}".format(slurm_file))
+    # Set PYTHONPATH to pycls copy for sweep
+    os.environ["PYTHONPATH"] = args.pycls_dir
+    prt("Using PYTHONPATH={}".format(args.pycls_dir))
+    # Generate srun command to launch
+    cmd_to_run = (
+        "srun"
+        " --output {out_dir}/stdout.log"
+        " --error {out_dir}/stderr.log"
+        " python {script}"
+        " --cfg {cfg}"
+        " OUT_DIR {out_dir}"
+    ).format(out_dir=out_dir, script=args.script_path, cfg=cfg_file)
+    prt("Running cmd:\n", cmd_to_run.replace(" ", "\n "))
+    # Run command in background using subprocess and wait so that signals can be caught
+    p = subprocess.Popen(cmd_to_run, shell=True)
+    prt("Waiting for job to complete")
+    p.wait()
+    prt("Completed waiting. Return code for job: {}".format(p.returncode))
+    if p.returncode != 0:
+        retry_file = os.path.join(out_dir, "RETRY.txt")
+        with open(retry_file, "a") as f:
+            f.write("Encountered non-zero exit code\n")
+        with open(retry_file, "r") as f:
+            retry_count = len(f.readlines()) - 1
+        prt("Retry count for job: {}".format(retry_count))
+        if retry_count < args.max_retry:
+            requeue_job()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tools/sweep_setup.py b/tools/sweep_setup.py
new file mode 100644
index 0000000..686db1a
--- /dev/null
+++ b/tools/sweep_setup.py
@@ -0,0 +1,126 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
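+
+# Example of the dedup key built in sample_cfgs below (hypothetical values): the
+# sampled params list ["REGNET.DEPTH", 16, "REGNET.W0", 96] becomes the key
+# "REGNET.DEPTH 16 REGNET.W0 96", so repeated draws of identical params collapse
+# to a single config.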
+ +"""Sample cfgs for a sweep using a sweep_cfg.""" + +import multiprocessing +import os + +import numpy as np +import pycls.models.scaler as scaler +import pycls.sweep.config as sweep_config +import pycls.sweep.samplers as samplers +import yaml +from pycls.core.config import cfg, reset_cfg +from pycls.core.timer import Timer +from pycls.sweep.config import sweep_cfg + + +def sample_cfgs(seed): + """Samples chunk configs and return those that are unique and valid.""" + # Fix RNG seed (every call to this function should use a unique seed) + np.random.seed(seed) + setup_cfg = sweep_cfg.SETUP + cfgs = {} + for _ in range(setup_cfg.CHUNK_SIZE): + # Sample parameters [key, val, ...] list based on the samplers + params = samplers.sample_parameters(setup_cfg.SAMPLERS) + # Check if config is unique, if not continue + key = zip(params[0::2], params[1::2]) + key = " ".join(["{} {}".format(k, v) for k, v in key]) + if key in cfgs: + continue + # Generate config from parameters + reset_cfg() + cfg.merge_from_other_cfg(setup_cfg.BASE_CFG) + cfg.merge_from_list(params) + # Check if config is valid, if not continue + is_valid = samplers.check_regnet_constraints(setup_cfg.CONSTRAINTS) + if not is_valid: + continue + # Special logic for dealing w model scaling (side effect is to standardize cfg) + scaler.scale_model() + # Check if config is valid, if not continue + is_valid = samplers.check_complexity_constraints(setup_cfg.CONSTRAINTS) + if not is_valid: + continue + # Set config description to key + cfg.DESC = key + # Store copy of config if unique and valid + cfgs[key] = cfg.clone() + # Stop sampling if already reached quota + if len(cfgs) == setup_cfg.NUM_CONFIGS: + break + return cfgs + + +def dump_cfg(cfg_file, cfg): + """Dumps the config to the specified location.""" + with open(cfg_file, "w") as f: + cfg.dump(stream=f) + + +def sweep_setup(): + """Samples cfgs for the sweep.""" + setup_cfg = sweep_cfg.SETUP + # Create output directories + sweep_dir = os.path.join(sweep_cfg.ROOT_DIR, sweep_cfg.NAME) + cfgs_dir = os.path.join(sweep_dir, "cfgs") + logs_dir = os.path.join(sweep_dir, "logs") + print("Sweep directory is: {}".format(sweep_dir)) + assert not os.path.exists(logs_dir), "Sweep already started: " + sweep_dir + if os.path.exists(logs_dir) or os.path.exists(cfgs_dir): + print("Overwriting sweep which has not yet launched") + os.makedirs(sweep_dir, exist_ok=True) + os.makedirs(cfgs_dir, exist_ok=True) + # Dump the original sweep_cfg + sweep_cfg_file = os.path.join(sweep_dir, "sweep_cfg.yaml") + os.system("cp {} {}".format(sweep_cfg.SWEEP_CFG_FILE, sweep_cfg_file)) + # Create worker pool for sampling and saving configs + n_proc, chunk = sweep_cfg.NUM_PROC, setup_cfg.CHUNK_SIZE + process_pool = multiprocessing.Pool(n_proc) + # Fix random number generator seed and generate per chunk seeds + np.random.seed(setup_cfg.RNG_SEED) + n_chunks = int(np.ceil(setup_cfg.NUM_SAMPLES / chunk)) + chunk_seeds = np.random.choice(1000000, size=n_chunks, replace=False) + # Sample configs in chunks using multiple workers each with a unique seed + info_str = "Number configs sampled: {}, configs kept: {} [t={:.2f}s]" + n_samples, n_cfgs, i, cfgs, timer = 0, 0, 0, {}, Timer() + while n_samples < setup_cfg.NUM_SAMPLES and n_cfgs < setup_cfg.NUM_CONFIGS: + timer.tic() + seeds = chunk_seeds[i * n_proc : i * n_proc + n_proc] + cfgs_all = process_pool.map(sample_cfgs, seeds) + cfgs = dict(cfgs, **{k: v for d in cfgs_all for k, v in d.items()}) + n_samples, n_cfgs, i = n_samples + chunk * n_proc, len(cfgs), i + 1 + timer.toc() 
+        print(info_str.format(n_samples, n_cfgs, timer.total_time))
+    # Randomize cfgs order and subsample if oversampled
+    keys, cfgs = list(cfgs.keys()), list(cfgs.values())
+    n_cfgs = min(n_cfgs, setup_cfg.NUM_CONFIGS)
+    ids = np.random.choice(len(cfgs), n_cfgs, replace=False)
+    keys, cfgs = [keys[i] for i in ids], [cfgs[i] for i in ids]
+    # Save the cfgs and a cfgs_summary
+    timer.tic()
+    cfg_names = ["{:06}.yaml".format(i) for i in range(n_cfgs)]
+    cfgs_summary = {cfg_name: key for cfg_name, key in zip(cfg_names, keys)}
+    with open(os.path.join(sweep_dir, "cfgs_summary.yaml"), "w") as f:
+        yaml.dump(cfgs_summary, f, width=float("inf"))
+    cfg_files = [os.path.join(cfgs_dir, cfg_name) for cfg_name in cfg_names]
+    process_pool.starmap(dump_cfg, zip(cfg_files, cfgs))
+    timer.toc()
+    print(info_str.format(n_samples, n_cfgs, timer.total_time))
+
+
+def main():
+    desc = "Set up sweep by generating job configs."
+    sweep_config.load_cfg_fom_args(desc)
+    sweep_cfg.freeze()
+    sweep_setup()
+
+
+if __name__ == "__main__":
+    main()
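
As a usage sketch (not part of this patch): after `sweep_collect.py` writes `sweep.json`, the per-job log data can be read back with plain `json`. The sweep name "example" and the default ROOT_DIR path are hypothetical here; the "log_file" and "test_epoch" keys follow the data assembled in `load_data` and `sweep_collect` above.

    import getpass
    import json
    import os

    # Path assumes the default ROOT_DIR ("/checkpoint/<user>/sweeps/") and a
    # hypothetical sweep named "example"
    sweep_file = os.path.join("/checkpoint", getpass.getuser(), "sweeps", "example", "sweep.json")
    with open(sweep_file) as f:
        sweep = json.load(f)
    # Print each job's log file and its last completed test epoch, if any
    for job in sweep:
        e = job.get("test_epoch")
        last_epoch = e["epoch_ind"][-1] if e else "n/a"
        print(job["log_file"], last_epoch)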