This repository has been archived by the owner on Nov 1, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 239
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Sweep code for studying model population stats (1 of 2) (#143)
Summary: This is a *major update* and introduces powerful new functionality to pycls. The pycls codebase now provides powerful support for studying *design spaces* and more generally *population statistics* of models as introduced in [On Network Design Spaces for Visual Recognition](https://arxiv.org/abs/1905.13214) and [Designing Network Design Spaces](https://arxiv.org/abs/2003.13678). This idea is that instead of planning a single pycls job (e.g., testing a specific model configuration), one can study the behavior of an entire population of models. This allows for quite powerful and succinct experimental design, and elevates the study of individual model behavior to the study of the behavior of model populations. Please see [`SWEEP_INFO`](docs/SWEEP_INFO.md) for details. This is commit 1 of 2 for the sweep code. It is focused on the sweep config, setting up the sweep, and launching it. Pull Request resolved: #143 Reviewed By: pdollar Differential Revision: D28580825 Pulled By: rajprateek fbshipit-source-id: 9221f0a9b3651642d2c6acd87befacd6521825cc Co-authored-by: Raj Prateek Kosaraju <[email protected]> Co-authored-by: Piotr Dollar <[email protected]>
- Loading branch information
1 parent
2c152a6
commit bd65938
Showing
8 changed files
with
922 additions
and
0 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,246 @@ | ||
#!/usr/bin/env python3 | ||
|
||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
"""Configuration file (powered by YACS).""" | ||
|
||
import argparse | ||
import getpass | ||
import multiprocessing | ||
import os | ||
import sys | ||
|
||
from pycls.core.config import cfg | ||
from pycls.sweep.samplers import validate_sampler | ||
from yacs.config import CfgNode as CfgNode | ||
|
||
|
||
# Example usage: from sweep.config import sweep_cfg | ||
sweep_cfg = _C = CfgNode() | ||
|
||
|
||
# ------------------------------ General sweep options ------------------------------- # | ||
|
||
# Sweeps root directory where all sweep output subdirectories will be placed | ||
_C.ROOT_DIR = "/checkpoint/{}/sweeps/".format(getpass.getuser()) | ||
|
||
# Sweep name must be unique per sweep and defines the output subdirectory | ||
_C.NAME = "" | ||
|
||
# Optional description of a sweep useful to keep track of sweeps | ||
_C.DESC = "" | ||
|
||
# Number of processes to use for various sweep steps except for running jobs | ||
_C.NUM_PROC = multiprocessing.cpu_count() | ||
|
||
# Automatically overwritten to the file from which the sweep_cfg is loaded | ||
_C.SWEEP_CFG_FILE = "" | ||
|
||
|
||
# ------------------------------- Sweep setup options -------------------------------- # | ||
_C.SETUP = CfgNode() | ||
|
||
# Max number of unique job configs to generate | ||
_C.SETUP.NUM_CONFIGS = 0 | ||
|
||
# Max number of attempts for generating NUM_CONFIGS valid configs | ||
_C.SETUP.NUM_SAMPLES = 1000000 | ||
|
||
# Specifies the chunk size to use per process while sampling configs | ||
_C.SETUP.CHUNK_SIZE = 5000 | ||
|
||
# Random seed for generating job configs | ||
_C.SETUP.RNG_SEED = 0 | ||
|
||
# Base config for all jobs, any valid config option in core.config is valid here | ||
_C.SETUP.BASE_CFG = cfg.clone() | ||
|
||
# Samplers to use for generating job configs (see SAMPLERS defined toward end of file) | ||
# SETUP.SAMPLERS should consists of a dictionary of SAMPLERS | ||
# Each dict key should be a valid parameter in the BASE_CFG (e.g. "MODEL.DEPTH") | ||
# Each dict val should be a valid SAMPLER that defines how to sample (e.g. INT_SAMPLER) | ||
# See the example sweep configs for more usage information | ||
_C.SETUP.SAMPLERS = CfgNode(new_allowed=True) | ||
|
||
# Constraints on generated configs | ||
_C.SETUP.CONSTRAINTS = CfgNode() | ||
|
||
# Complexity constraints CX on models specified as a [LOW, HIGH] range, e.g. [0, 1.0e+6] | ||
# If LOW == HIGH == 0 for a given complexity constraint that constraint is not applied | ||
# For RegNets, if flops<F (B), setting params<3+5.5F and acts<6.5*sqrt(F) (M) works well | ||
_C.SETUP.CONSTRAINTS.CX = CfgNode() | ||
_C.SETUP.CONSTRAINTS.CX.FLOPS = [0, 0] | ||
_C.SETUP.CONSTRAINTS.CX.PARAMS = [0, 0] | ||
_C.SETUP.CONSTRAINTS.CX.ACTS = [0, 0] | ||
|
||
# RegNet specific constraints | ||
_C.SETUP.CONSTRAINTS.REGNET = CfgNode() | ||
_C.SETUP.CONSTRAINTS.REGNET.NUM_STAGES = [4, 4] | ||
|
||
|
||
# ------------------------------- Sweep launch options ------------------------------- # | ||
_C.LAUNCH = CfgNode() | ||
|
||
# Actual script to run for each job (should be in pycls directory) | ||
_C.LAUNCH.SCRIPT = "tools/train_net.py" | ||
|
||
# CONDA environment to use for jobs (defaults to current environment) | ||
_C.LAUNCH.CONDA_ENV = os.environ["CONDA_PREFIX"] | ||
|
||
# Max number of parallel jobs to run (subject to resource constraints) | ||
_C.LAUNCH.PARALLEL_JOBS = 128 | ||
|
||
# Max number of times to retry a job | ||
_C.LAUNCH.MAX_RETRY = 3 | ||
|
||
# Optional comment for sbatch (may be required when using high priority partitions) | ||
_C.LAUNCH.COMMENT = "" | ||
|
||
# Resources to request per job | ||
_C.LAUNCH.NUM_GPUS = 1 | ||
_C.LAUNCH.CPUS_PER_GPU = 10 | ||
_C.LAUNCH.MEM_PER_GPU = 60 | ||
_C.LAUNCH.PARTITION = "learnfair" | ||
_C.LAUNCH.GPU_TYPE = "volta" | ||
_C.LAUNCH.TIME_LIMIT = 4200 | ||
|
||
|
||
# ------------------------------ Sweep collect options ------------------------------- # | ||
_C.COLLECT = CfgNode() | ||
|
||
# Determines which checkpoints to keep, supported options are "all", "last", or "none" | ||
_C.COLLECT.CHECKPOINTS_KEEP = "last" | ||
|
||
|
||
# ------------------------------ Sweep analysis options ------------------------------ # | ||
_C.ANALYZE = CfgNode() | ||
|
||
# List of metrics for which to generate analysis, may be any valid field in log | ||
# An example metric is "cfg.OPTIM.BASE_LR" or "complexity.acts" or "test_epoch.mem" | ||
# Some metrics have shortcuts defined, for example "error" or "lr", see analysis.py | ||
_C.ANALYZE.METRICS = [] | ||
|
||
# List of complexity metrics for which to generate analysis, same format as metrics | ||
_C.ANALYZE.COMPLEXITY = ["flops", "params", "acts"] | ||
|
||
# Controls number of plots of various types to show in analysis | ||
_C.ANALYZE.PLOT_METRIC_VALUES = True | ||
_C.ANALYZE.PLOT_METRIC_TRENDS = True | ||
_C.ANALYZE.PLOT_COMPLEXITY_VALUES = True | ||
_C.ANALYZE.PLOT_COMPLEXITY_TRENDS = True | ||
_C.ANALYZE.PLOT_CURVES_BEST = 0 | ||
_C.ANALYZE.PLOT_CURVES_WORST = 0 | ||
_C.ANALYZE.PLOT_MODELS_BEST = 0 | ||
_C.ANALYZE.PLOT_MODELS_WORST = 0 | ||
|
||
# Undocumented "use at your own risk" feature used to "pre-filter" a sweep | ||
_C.ANALYZE.PRE_FILTERS = CfgNode(new_allowed=True) | ||
|
||
# Undocumented "use at your own risk" feature used to "split" a sweep into sets | ||
_C.ANALYZE.SPLIT_FILTERS = CfgNode(new_allowed=True) | ||
|
||
# Undocumented "use at your own risk" feature used to load other sweeps | ||
_C.ANALYZE.EXTRA_SWEEP_NAMES = [] | ||
|
||
|
||
# --------------------------- Samplers for SETUP.SAMPLERS ---------------------------- # | ||
SAMPLERS = CfgNode() | ||
|
||
# Sampler for uniform sampling from a list of values | ||
SAMPLERS.VALUE_SAMPLER = CfgNode() | ||
SAMPLERS.VALUE_SAMPLER.TYPE = "value_sampler" | ||
SAMPLERS.VALUE_SAMPLER.VALUES = [] | ||
|
||
# Sampler for floats with RAND_TYPE sampling in RANGE quantized to QUANTIZE | ||
# RAND_TYPE can be "uniform", "log_uniform", "power2_uniform", "normal", "log_normal" | ||
# Uses the closed interval RANGE = [LOW, HIGH] (so the HIGH value can be sampled) | ||
# Note that both LOW and HIGH must be divisible by QUANTIZE | ||
# For the (clipped) normal samplers mu/sigma are set so ~99.7% of samples are in RANGE | ||
SAMPLERS.FLOAT_SAMPLER = CfgNode() | ||
SAMPLERS.FLOAT_SAMPLER.TYPE = "float_sampler" | ||
SAMPLERS.FLOAT_SAMPLER.RAND_TYPE = "uniform" | ||
SAMPLERS.FLOAT_SAMPLER.RANGE = [0.0, 0.0] | ||
SAMPLERS.FLOAT_SAMPLER.QUANTIZE = 0.00001 | ||
|
||
# Sampler for ints with RAND_TYPE sampling in RANGE quantized to QUANTIZE | ||
# RAND_TYPE can be "uniform", "log_uniform", "power2_uniform", "normal", "log_normal" | ||
# Uses the closed interval RANGE = [LOW, HIGH] (so the HIGH value can be sampled) | ||
# Note that both LOW and HIGH must be divisible by QUANTIZE | ||
# For the (clipped) normal samplers mu/sigma are set so ~99.7% of samples are in RANGE | ||
SAMPLERS.INT_SAMPLER = CfgNode() | ||
SAMPLERS.INT_SAMPLER.TYPE = "int_sampler" | ||
SAMPLERS.INT_SAMPLER.RAND_TYPE = "uniform" | ||
SAMPLERS.INT_SAMPLER.RANGE = [0, 0] | ||
SAMPLERS.INT_SAMPLER.QUANTIZE = 1 | ||
|
||
# Sampler for a list of LENGTH items each sampled independently by the ITEM_SAMPLER | ||
# The ITEM_SAMPLER can be any sampler (like INT_SAMPLER or even anther LIST_SAMPLER) | ||
SAMPLERS.LIST_SAMPLER = CfgNode() | ||
SAMPLERS.LIST_SAMPLER.TYPE = "list_sampler" | ||
SAMPLERS.LIST_SAMPLER.LENGTH = 0 | ||
SAMPLERS.LIST_SAMPLER.ITEM_SAMPLER = CfgNode(new_allowed=True) | ||
|
||
# RegNet Sampler with ranges for REGNET params (see base config for meaning of params) | ||
# This sampler simply allows a compact specification of a number of RegNet params | ||
# QUANTIZE for each params below is fixed to: 1, 8, 0.1, 0.001, 8, 1/128, respectively | ||
# RAND_TYPE for each is fixed to uni, log, log, log, power2_or_log, power2, respectively | ||
# Default parameter ranges are set to generate fairly good performing models up to 16GF | ||
# For models over 16GF, higher ranges for GROUP_W, W0, and WA are necessary | ||
# If including this sampler set SETUP.CONSTRAINTS as needed | ||
SAMPLERS.REGNET_SAMPLER = CfgNode() | ||
SAMPLERS.REGNET_SAMPLER.TYPE = "regnet_sampler" | ||
SAMPLERS.REGNET_SAMPLER.DEPTH = [12, 28] | ||
SAMPLERS.REGNET_SAMPLER.W0 = [8, 256] | ||
SAMPLERS.REGNET_SAMPLER.WA = [8.0, 256.0] | ||
SAMPLERS.REGNET_SAMPLER.WM = [2.0, 3.0] | ||
SAMPLERS.REGNET_SAMPLER.GROUP_W = [8, 128] | ||
SAMPLERS.REGNET_SAMPLER.BOT_MUL = [1.0, 1.0] | ||
|
||
|
||
# -------------------------------- Utility functions --------------------------------- # | ||
def load_cfg(sweep_cfg_file): | ||
"""Loads config from specified sweep_cfg_file.""" | ||
_C.merge_from_file(sweep_cfg_file) | ||
_C.SWEEP_CFG_FILE = os.path.abspath(sweep_cfg_file) | ||
# Check for required arguments | ||
err_msg = "{} has to be specified." | ||
assert _C.ROOT_DIR, err_msg.format("ROOT_DIR") | ||
assert _C.NAME, err_msg.format("NAME") | ||
assert _C.SETUP.NUM_CONFIGS, err_msg.format("SETUP.NUM_CONFIGS") | ||
# Check for allowed arguments | ||
opts = ["all", "last", "none"] | ||
err_msg = "COLLECT.CHECKPOINTS_KEEP has to be one of {}".format(opts) | ||
assert _C.COLLECT.CHECKPOINTS_KEEP in opts, err_msg | ||
# Setup the base config (note: this only alters the loaded global cfg) | ||
cfg.merge_from_other_cfg(_C.SETUP.BASE_CFG) | ||
# Load and validate each sampler against one of the SAMPLERS templates | ||
for param, sampler in _C.SETUP.SAMPLERS.items(): | ||
_C.SETUP.SAMPLERS[param] = load_sampler(param, sampler) | ||
|
||
|
||
def load_sampler(param, sampler): | ||
"""Loads and validates a sampler against one of the SAMPLERS templates.""" | ||
sampler_type = sampler.TYPE.upper() if "TYPE" in sampler else None | ||
err_msg = "Sampler for '{}' has an unknown or missing TYPE:\n{}" | ||
assert sampler_type in SAMPLERS, err_msg.format(param, sampler) | ||
full_sampler = SAMPLERS[sampler_type].clone() | ||
full_sampler.merge_from_other_cfg(sampler) | ||
validate_sampler(param, full_sampler) | ||
if sampler_type == "LIST_SAMPLER": | ||
full_sampler.ITEM_SAMPLER = load_sampler(param, sampler.ITEM_SAMPLER) | ||
return full_sampler | ||
|
||
|
||
def load_cfg_fom_args(description="Config file options."): | ||
"""Loads sweep cfg from command line argument.""" | ||
parser = argparse.ArgumentParser(description=description) | ||
help_str = "Path to sweep_cfg file" | ||
parser.add_argument("--sweep-cfg", help=help_str, required=True) | ||
args = parser.parse_args() | ||
if len(sys.argv) == 1: | ||
parser.print_help() | ||
sys.exit(1) | ||
load_cfg(args.sweep_cfg) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
#!/usr/bin/env python3 | ||
|
||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
"""Functions for sampling in the closed interval [low, high] quantized by q.""" | ||
|
||
from decimal import Decimal | ||
|
||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
|
||
|
||
def quantize(f, q, op=np.floor): | ||
"""Quantizes f to be divisible by q and have q's type.""" | ||
quantized = Decimal(op(f / q)) * Decimal(str(q)) | ||
return type(q)(quantized) | ||
|
||
|
||
def uniform(low, high, q): | ||
"""Samples uniform value from [low, high] quantized to q.""" | ||
# Samples f in [l, h+q) then quantizes f to [l, h] via floor() | ||
# This is equivalent to sampling f in (l-q, h] then quantizing via ceil() | ||
f = np.random.uniform(low, high + q) | ||
return quantize(f, q, np.floor) | ||
|
||
|
||
def log_uniform(low, high, q): | ||
"""Samples log uniform value from [low, high] quantized to q.""" | ||
# Samples f in (l-q*, h] then quantizes f to [l, h] via ceil(), where q*=min(q,l/2) | ||
# This is NOT equivalent to sampling f in [l, h-q) then quantizing via floor() | ||
f = np.exp(-np.random.uniform(-(np.log(high)), -(np.log(low - min(q, low / 2))))) | ||
return quantize(f, q, np.ceil) | ||
|
||
|
||
def power2_uniform(low, high, q): | ||
"""Samples uniform powers of 2 from [low, high] quantized to q.""" | ||
# Samples f2 in [l2, h2+1) then quantizes f2 to [l2, h2] via floor() | ||
f2 = np.floor(np.random.uniform(np.log2(low), np.log2(high) + 1)) | ||
return quantize(2 ** f2, q) | ||
|
||
|
||
def power2_or_log_uniform(low, high, q): | ||
"""Samples uniform powers of 2 or values divisible by q from [low, high].""" | ||
# The overall CDF is log-linear because range in log_uniform is (q/2, high] | ||
f = type(q)(power2_uniform(low, high, low)) | ||
f = log_uniform(max(low, q), high, min(high, q)) if f >= q else f | ||
return f | ||
|
||
|
||
def normal(low, high, q): | ||
"""Samples values from a clipped normal (Gaussian) distribution quantized to q.""" | ||
# mu/sigma are computed from low/high such that ~99.7% of samples are in range | ||
f, mu, sigma = np.inf, (low + high) / 2, (high - low) / 6 | ||
while not low <= f <= high: | ||
f = np.random.normal(mu, sigma) | ||
return quantize(f, q, np.round) | ||
|
||
|
||
def log_normal(low, high, q): | ||
"""Samples values from a clipped log-normal distribution quantized to q.""" | ||
# mu/sigma are computed from low/high such that ~99.7% of samples are in range | ||
log_low, log_high = np.log(low), np.log(high) | ||
f, mu, sigma = np.inf, (log_low + log_high) / 2, (log_high - log_low) / 6 | ||
while not low <= f <= high: | ||
f = np.random.lognormal(mu, sigma) | ||
return quantize(f, q, np.round) | ||
|
||
|
||
rand_types = { | ||
"uniform": uniform, | ||
"log_uniform": log_uniform, | ||
"power2_uniform": power2_uniform, | ||
"power2_or_log_uniform": power2_or_log_uniform, | ||
"normal": normal, | ||
"log_normal": log_normal, | ||
} | ||
|
||
|
||
def validate_rand(err_str, rand_type, low, high, q): | ||
"""Validate parameters to random number generators.""" | ||
err_msg = "{}: {}(low={}, high={}, q={}) is invalid." | ||
err_msg = err_msg.format(err_str, rand_type, low, high, q) | ||
low_q = Decimal(str(low)) % Decimal(str(q)) == 0 | ||
high_q = Decimal(str(high)) % Decimal(str(q)) == 0 | ||
assert type(q) == type(low) == type(high), err_msg | ||
assert rand_type in rand_types, err_msg | ||
assert q > 0 and low <= high, err_msg | ||
assert low > 0 or rand_type in ["uniform", "normal"], err_msg | ||
assert low_q and high_q or rand_type == "power2_or_log_uniform", err_msg | ||
if rand_type in ["power2_uniform", "power2_or_log_uniform"]: | ||
assert all(np.log2(v).is_integer() for v in [low, high, q]), err_msg | ||
|
||
|
||
def plot_rand_cdf(rand_type, low, high, q, n=10000): | ||
"""Visualizes CDF of rand_fun, resulting CDF should be linear (or log-linear).""" | ||
validate_rand("plot_rand_cdf", rand_type, low, high, q) | ||
samples = [rand_types[rand_type](low, high, q) for _ in range(n)] | ||
unique = list(np.unique(samples)) | ||
assert min(unique) >= low and max(unique) <= high, "Sampled value out of range." | ||
cdf = np.cumsum(np.histogram(samples, unique + [np.inf])[0]) / len(samples) | ||
plot_fun = plt.plot if rand_type in ["uniform", "normal"] else plt.semilogx | ||
plot_fun(unique, cdf, "o-", [low, low], [0, 1], "-k", [high, high], [0, 1], "-k") | ||
plot_fun([low, high], [cdf[0], cdf[-1]]) if "normal" not in rand_type else () | ||
plt.title("{}(low={}, high={}, q={})".format(rand_type, low, high, q)) | ||
plt.show() | ||
|
||
|
||
def plot_rand_cdfs(): | ||
"""Visualize CDFs of selected distributions, for visualization/debugging only.""" | ||
plot_rand_cdf("uniform", -0.5, 0.5, 0.1) | ||
plot_rand_cdf("power2_uniform", 2, 512, 1) | ||
plot_rand_cdf("power2_uniform", 0.25, 8.0, 0.25) | ||
plot_rand_cdf("log_uniform", 1, 32, 1) | ||
plot_rand_cdf("log_uniform", 0.5, 16.0, 0.5) | ||
plot_rand_cdf("power2_or_log_uniform", 1.0, 16.0, 1.0) | ||
plot_rand_cdf("power2_or_log_uniform", 0.25, 4.0, 4.0) | ||
plot_rand_cdf("power2_or_log_uniform", 1, 128, 4) |
Oops, something went wrong.