diff --git a/README.md b/README.md
index 2807bce..d43ce2a 100644
--- a/README.md
+++ b/README.md
@@ -9,9 +9,7 @@
## Introduction

-The goal of **pycls** is to provide a simple and flexible codebase for image classification. It is designed to support rapid implementation and evaluation of research ideas. **pycls** also provides a large collection of baseline results ([Model Zoo](MODEL_ZOO.md)).
-
-The codebase supports efficient single-machine multi-gpu training, powered by the PyTorch distributed package, and provides implementations of standard models including [ResNet](https://arxiv.org/abs/1512.03385), [ResNeXt](https://arxiv.org/abs/1611.05431), [EfficientNet](https://arxiv.org/abs/1905.11946), and [RegNet](https://arxiv.org/abs/2003.13678).
+The goal of **pycls** is to provide a simple and flexible codebase for image classification. It is designed to support rapid implementation and evaluation of research ideas. **pycls** also provides a large collection of baseline results ([Model Zoo](MODEL_ZOO.md)). The codebase supports efficient single-machine multi-gpu training, powered by the PyTorch distributed package, and provides implementations of standard models including [ResNet](https://arxiv.org/abs/1512.03385), [ResNeXt](https://arxiv.org/abs/1611.05431), [EfficientNet](https://arxiv.org/abs/1905.11946), and [RegNet](https://arxiv.org/abs/2003.13678).

## Using pycls

@@ -21,6 +19,10 @@ Please see [`GETTING_STARTED`](docs/GETTING_STARTED.md) for brief installation i

We provide a large set of baseline results and pretrained models available for download in the **pycls** [Model Zoo](MODEL_ZOO.md), including the simple, fast, and effective [RegNet](https://arxiv.org/abs/2003.13678) models that we hope can serve as solid baselines across a wide range of flop regimes.

+## Sweep Code
+
+The pycls codebase now provides powerful support for studying *design spaces* and more generally *population statistics* of models as introduced in [On Network Design Spaces for Visual Recognition](https://arxiv.org/abs/1905.13214) and [Designing Network Design Spaces](https://arxiv.org/abs/2003.13678). The idea is that instead of planning a single pycls job (e.g., testing a specific model configuration), one can study the behavior of an entire population of models. This allows for quite powerful and succinct experimental design, and elevates the study of individual model behavior to the study of the behavior of model populations. Please see [`SWEEP_INFO`](docs/SWEEP_INFO.md) for details.
+
## Projects

A number of projects at FAIR have been built on top of **pycls**:
@@ -28,6 +30,7 @@ A number of projects at FAIR have been built on top of **pycls**:
- [On Network Design Spaces for Visual Recognition](https://arxiv.org/abs/1905.13214)
- [Exploring Randomly Wired Neural Networks for Image Recognition](https://arxiv.org/abs/1904.01569)
- [Designing Network Design Spaces](https://arxiv.org/abs/2003.13678)
+- [Fast and Accurate Model Scaling](https://arxiv.org/abs/2103.06877)
- [Are Labels Necessary for Neural Architecture Search?](https://arxiv.org/abs/2003.12056)
- [PySlowFast Video Understanding Codebase](https://github.com/facebookresearch/SlowFast)

@@ -40,22 +43,29 @@ If you find **pycls** helpful in your research or refer to the baseline results
```
@InProceedings{Radosavovic2019,
  title = {On Network Design Spaces for Visual Recognition},
-  author = {Radosavovic, Ilija and Johnson, Justin and Xie, Saining and Lo, Wan-Yen and Doll{\'a}r, Piotr},
+  author = {Ilija Radosavovic and Justin Johnson and Saining Xie and Wan-Yen Lo and Piotr Doll{\'a}r},
  booktitle = {ICCV},
  year = {2019}
}

@InProceedings{Radosavovic2020,
  title = {Designing Network Design Spaces},
-  author = {Radosavovic, Ilija and Kosaraju, Raj Prateek and Girshick, Ross and He, Kaiming and Doll{\'a}r, Piotr},
+  author = {Ilija Radosavovic and Raj Prateek Kosaraju and Ross Girshick and Kaiming He and Piotr Doll{\'a}r},
  booktitle = {CVPR},
  year = {2020}
}
+
+@InProceedings{Dollar2021,
+  title = {Fast and Accurate Model Scaling},
+  author = {Piotr Doll{\'a}r and Mannat Singh and Ross Girshick},
+  booktitle = {CVPR},
+  year = {2021}
+}
```

## License

-**pycls** is released under the MIT license. Please see the [LICENSE](LICENSE) file for more information.
+**pycls** is released under the MIT license. Please see the [`LICENSE`](LICENSE) file for more information.

## Contributing

diff --git a/configs/sweeps/cifar/cifar_best.yaml b/configs/sweeps/cifar/cifar_best.yaml
new file mode 100644
index 0000000..ed6916f
--- /dev/null
+++ b/configs/sweeps/cifar/cifar_best.yaml
@@ -0,0 +1,87 @@
DESC:
  Example CIFAR sweep 3 of 3 (trains the best model from the cifar_regnet sweep).
  Train the best RegNet-125MF from the cifar_regnet sweep for variable epoch lengths.
  Trains 3 copies of every model (to obtain mean and std of the error).
  The purpose of this sweep is to show how to train the FINAL version of a model.
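# Note: OPTIM.MAX_EPOCH is drawn from the 4 VALUES below and RNG_SEED from [1, 3],
# so NUM_CONFIGS: 12 corresponds to 4 epoch settings x 3 seeds (3 copies per model).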
+NAME: cifar/cifar_best +SETUP: + # Number of configs to sample + NUM_CONFIGS: 12 + # SAMPLERS for optimization parameters + SAMPLERS: + OPTIM.MAX_EPOCH: + TYPE: value_sampler + VALUES: [50, 100, 200, 400] + RNG_SEED: + TYPE: int_sampler + RAND_TYPE: uniform + RANGE: [1, 3] + QUANTIZE: 1 + CONSTRAINTS: + REGNET: + NUM_STAGES: [2, 2] + # BASE_CFG is RegNet-125MF (best model from cifar_regnet sweep) + BASE_CFG: + MODEL: + TYPE: regnet + NUM_CLASSES: 10 + REGNET: + STEM_TYPE: res_stem_cifar + SE_ON: True + STEM_W: 16 + DEPTH: 12 + W0: 96 + WA: 19.5 + WM: 2.942 + GROUP_W: 8 + OPTIM: + BASE_LR: 1.0 + LR_POLICY: cos + MAX_EPOCH: 50 + MOMENTUM: 0.9 + NESTEROV: True + WARMUP_EPOCHS: 5 + WEIGHT_DECAY: 0.0005 + EMA_ALPHA: 0.00025 + EMA_UPDATE_PERIOD: 32 + BN: + USE_CUSTOM_WEIGHT_DECAY: True + TRAIN: + DATASET: cifar10 + SPLIT: train + BATCH_SIZE: 1024 + IM_SIZE: 32 + MIXED_PRECISION: True + LABEL_SMOOTHING: 0.1 + MIXUP_ALPHA: 0.5 + TEST: + DATASET: cifar10 + SPLIT: test + BATCH_SIZE: 1000 + IM_SIZE: 32 + NUM_GPUS: 1 + DATA_LOADER: + NUM_WORKERS: 4 + LOG_PERIOD: 25 + VERBOSE: False +# Launch config options +LAUNCH: + PARTITION: devlab + NUM_GPUS: 1 + PARALLEL_JOBS: 12 + TIME_LIMIT: 180 +# Analyze config options +ANALYZE: + PLOT_METRIC_VALUES: False + PLOT_COMPLEXITY_VALUES: False + PLOT_CURVES_BEST: 3 + PLOT_CURVES_WORST: 0 + PLOT_MODELS_BEST: 1 + METRICS: [] + COMPLEXITY: [flops, params, acts, memory, epoch_fw_bw, epoch_time] + PRE_FILTERS: {done: [0, 1, 1]} + SPLIT_FILTERS: + epochs=050: {cfg.OPTIM.MAX_EPOCH: [ 50, 50, 50]} + epochs=100: {cfg.OPTIM.MAX_EPOCH: [100, 100, 100]} + epochs=200: {cfg.OPTIM.MAX_EPOCH: [200, 200, 200]} + epochs=400: {cfg.OPTIM.MAX_EPOCH: [400, 400, 400]} diff --git a/configs/sweeps/cifar/cifar_optim.yaml b/configs/sweeps/cifar/cifar_optim.yaml new file mode 100644 index 0000000..dc74020 --- /dev/null +++ b/configs/sweeps/cifar/cifar_optim.yaml @@ -0,0 +1,76 @@ +DESC: + Example CIFAR sweep 1 of 3 (find lr and wd for cifar_regnet and cifar_best sweeps). + Tunes the learning rate (lr) and weight decay (wd) for ResNet-56 at 50 epochs. + The purpose of this sweep is to show how to optimize OPTIM parameters. 
+NAME: cifar/cifar_optim +SETUP: + # Number of configs to sample + NUM_CONFIGS: 64 + # SAMPLERS for optimization parameters + SAMPLERS: + OPTIM.BASE_LR: + TYPE: float_sampler + RAND_TYPE: log_uniform + RANGE: [0.25, 5.0] + QUANTIZE: 1.0e-10 + OPTIM.WEIGHT_DECAY: + TYPE: float_sampler + RAND_TYPE: log_uniform + RANGE: [5.0e-5, 1.0e-3] + QUANTIZE: 1.0e-10 + # BASE_CFG is R-56 with large batch size and stronger augmentation + BASE_CFG: + MODEL: + TYPE: anynet + NUM_CLASSES: 10 + ANYNET: + STEM_TYPE: res_stem_cifar + STEM_W: 16 + BLOCK_TYPE: res_basic_block + DEPTHS: [9, 9, 9] + WIDTHS: [16, 32, 64] + STRIDES: [1, 2, 2] + OPTIM: + BASE_LR: 1.0 + LR_POLICY: cos + MAX_EPOCH: 50 + MOMENTUM: 0.9 + NESTEROV: True + WARMUP_EPOCHS: 5 + WEIGHT_DECAY: 0.0005 + EMA_ALPHA: 0.00025 + EMA_UPDATE_PERIOD: 32 + BN: + USE_CUSTOM_WEIGHT_DECAY: True + TRAIN: + DATASET: cifar10 + SPLIT: train + BATCH_SIZE: 1024 + IM_SIZE: 32 + MIXED_PRECISION: True + LABEL_SMOOTHING: 0.1 + MIXUP_ALPHA: 0.5 + TEST: + DATASET: cifar10 + SPLIT: test + BATCH_SIZE: 1000 + IM_SIZE: 32 + NUM_GPUS: 1 + DATA_LOADER: + NUM_WORKERS: 4 + LOG_PERIOD: 25 + VERBOSE: False +# Launch config options +LAUNCH: + PARTITION: devlab + NUM_GPUS: 1 + PARALLEL_JOBS: 32 + TIME_LIMIT: 60 +# Analyze config options +ANALYZE: + PLOT_CURVES_BEST: 3 + PLOT_METRIC_VALUES: True + PLOT_COMPLEXITY_VALUES: True + METRICS: [lr, wd, lr_wd] + COMPLEXITY: [flops, params, acts, memory, epoch_fw_bw, epoch_time] + PRE_FILTERS: {done: [1, 1, 1]} diff --git a/configs/sweeps/cifar/cifar_regnet.yaml b/configs/sweeps/cifar/cifar_regnet.yaml new file mode 100644 index 0000000..f996b84 --- /dev/null +++ b/configs/sweeps/cifar/cifar_regnet.yaml @@ -0,0 +1,78 @@ +DESC: + Example CIFAR sweep 2 of 3 (uses lr and wd found by cifar_optim sweep). + This sweep searches for a good RegNet-125MF model on cifar (same flops as R56). + The purpose of this sweep is to show how to optimize REGNET parameters. 
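# Note: sampled configs that fall outside the CX CONSTRAINTS below are rejected
# and resampled, so setup may take longer when the constraints are tight.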
+NAME: cifar/cifar_regnet
SETUP:
  # Number of configs to sample
  NUM_CONFIGS: 32
  # SAMPLER for RegNet
  SAMPLERS:
    REGNET:
      TYPE: regnet_sampler
      DEPTH: [6, 16]
      GROUP_W: [1, 32]
  # CONSTRAINTS for complexity (roughly based on R-56)
  CONSTRAINTS:
    CX:
      FLOPS: [0.12e+9, 0.13e+9]
      PARAMS: [0, 2.0e+6]
      ACTS: [0, 1.0e+6]
    REGNET:
      NUM_STAGES: [2, 2]
  # BASE_CFG follows the cifar_optim training setup (large batch size and stronger augmentation)
  BASE_CFG:
    MODEL:
      TYPE: regnet
      NUM_CLASSES: 10
    REGNET:
      STEM_TYPE: res_stem_cifar
      SE_ON: True
      STEM_W: 16
    OPTIM:
      BASE_LR: 1.0
      LR_POLICY: cos
      MAX_EPOCH: 50
      MOMENTUM: 0.9
      NESTEROV: True
      WARMUP_EPOCHS: 5
      WEIGHT_DECAY: 0.0005
      EMA_ALPHA: 0.00025
      EMA_UPDATE_PERIOD: 32
    BN:
      USE_CUSTOM_WEIGHT_DECAY: True
    TRAIN:
      DATASET: cifar10
      SPLIT: train
      BATCH_SIZE: 1024
      IM_SIZE: 32
      MIXED_PRECISION: True
      LABEL_SMOOTHING: 0.1
      MIXUP_ALPHA: 0.5
    TEST:
      DATASET: cifar10
      SPLIT: test
      BATCH_SIZE: 1000
      IM_SIZE: 32
    NUM_GPUS: 1
    DATA_LOADER:
      NUM_WORKERS: 4
    LOG_PERIOD: 25
    VERBOSE: False
# Launch config options
LAUNCH:
  PARTITION: devlab
  NUM_GPUS: 1
  PARALLEL_JOBS: 32
  TIME_LIMIT: 60
# Analyze config options
ANALYZE:
  PLOT_METRIC_VALUES: True
  PLOT_COMPLEXITY_VALUES: True
  PLOT_CURVES_BEST: 3
  PLOT_CURVES_WORST: 0
  PLOT_MODELS_BEST: 8
  PLOT_MODELS_WORST: 0
  METRICS: [regnet_depth, regnet_w0, regnet_wa, regnet_wm, regnet_gw]
  COMPLEXITY: [flops, params, acts, memory, epoch_fw_bw, epoch_time]
  PRE_FILTERS: {done: [0, 1, 1]}
diff --git a/docs/DATA.md b/docs/DATA.md
index e61817e..5002fa0 100644
--- a/docs/DATA.md
+++ b/docs/DATA.md
@@ -36,14 +36,14 @@ Create a directory containing symlinks:
mkdir -p /path/pycls/pycls/datasets/data
```

-Symlink ImageNet:
+Symlink ImageNet (`/datasets01/imagenet_full_size/061417/` on FAIR cluster):
```
-ln -s /path/imagenet /path/pycls/pycls/datasets/data/imagenet
+ln -sv /path/imagenet /path/pycls/pycls/datasets/data/imagenet
```

-Symlink CIFAR-10:
+Symlink CIFAR-10 (`/datasets01/cifar-10-batches-py/060817/` on FAIR cluster):
```
-ln -s /path/cifar10 /path/pycls/pycls/datasets/data/cifar10
+ln -sv /path/cifar10 /path/pycls/pycls/datasets/data/cifar10
```
diff --git a/docs/GETTING_STARTED.md b/docs/GETTING_STARTED.md
index a97cf63..477b0f0 100644
--- a/docs/GETTING_STARTED.md
+++ b/docs/GETTING_STARTED.md
@@ -97,7 +97,7 @@ python tools/time_net.py
PREC_TIME.NUM_ITER 50
```

-### MODEL SCALING
+### Model Scaling

Scale a RegNetY-4GF by 4x using fast compound scaling (see https://arxiv.org/abs/2103.06877):
diff --git a/docs/SWEEP_INFO.md b/docs/SWEEP_INFO.md
new file mode 100644
index 0000000..2c0453a
--- /dev/null
+++ b/docs/SWEEP_INFO.md
@@ -0,0 +1,148 @@
# Sweeps and Design Spaces

The *sweep* code in pycls provides support for studying *design spaces* and more generally *population statistics* of models. The idea is that instead of planning a single pycls job (e.g., testing a specific model configuration), one can study the behavior of an entire population of models. This allows for quite powerful and succinct experimental design, and elevates the study of individual model behavior to the study of the behavior of model populations.
+ +This doc is organized as follows: +- [Introduction and background](#introduction-and-background) +- [Sweep prerequisites](#sweep-prerequisites) +- [Sweep overview](#sweep-overview) +- [Sweep examples](#sweep-examples) + +## Introduction and background + +The concept of network *design spaces* was introduced in [On Network Design Spaces for Visual Recognition](https://arxiv.org/abs/1905.13214) and [Designing Network Design Spaces](https://arxiv.org/abs/2003.13678). A design space is a large, possibly infinite, population of model architectures. The core insight is that we can sample models from a design space, giving rise to a model distribution, and turn to tools from classical statistics to analyze the design space. Instead of studying the behavior of an individual model, we study the *population statistics* of a collection of models. For example, when studying network design, we can aim to find a single best model under a specific setting (as in model search), or we can aim to study a diverse population of models to understand more general design principles that make for effective models. Typically, the latter allows us to learn and *generalize* to new settings, and makes for more robust findings that are likely to hold under more diverse settings. We recommend reading the above mentioned papers for further motivation and details. + +We operationalize the study of design spaces and population statistics by introducing a very flexible notion of *sweeps*. Simply put, a sweep is a population level experiment consisting of a number of individual pycls jobs. As mentioned, studying the *population statistics* of models can be more informative than designing and testing an individual model. A sweep can be as simple as a grid or random search over a hyperparameter (e.g., model depth, learning rate, etc.). However, a sweep can be far more general than a hyperparameter search, and can be used to study the behavior of diverse populations of models simultaneously varying along many dimensions. + +Just like a single training job in pycls is defined by a config, a sweep is likewise defined by a *sweep config*. This meta-level sweep config is an extremely powerful concept and elevates experimental design to the sweep level. So, rather than creating a large number of individual pycls configs, we can create a single sweep config to generate and study a population of models. A sweep config defines everything about a sweep, including: (1) the sweep setup options (defines how to sample pycls configs), (2) sweep launch options (how to launch the sweep to a cluster), and (3) sweep analysis options (to generate an analysis of the model population statistics). Rather than going into more detail here, we suggest studying the examples below and looking at the code documentation. + +## Sweep prerequisites + +Before beginning with the sweep code, please make sure to complete the following steps: + +- Please ensure that you can run individual pycls jobs by following the steps in [`GETTING_STARTED.md`](GETTING_STARTED.md). You should be able to successfully run an individual pycls job (e.g., training a model) prior to running a sweep. + +- Instructions and tools mentioned here were designed to work on *SLURM managed clusters*. 
While most of the code could easily be adapted to other clusters (in particular only [`sweep_launch.py`](../tools/sweep_launch.py) and [`sweep_launch_job.py`](../tools/sweep_launch_job.py) would likely need to be altered for a non-SLURM managed cluster), the pycls code only supports SLURM managed clusters out of the box.

- For simplicity and to successfully run the examples below, we recommend changing all the files in ./tools to be executable by the user (for example: `chmod 744 ./tools/*.py`). Of course, as an alternative, one can instead execute any of the python scripts by invoking python explicitly.

## Sweep overview

### The sweep config

A sweep config consists of four main parts (usage described in more detail shortly):
- `SETUP` options: used to specify a base pycls config along with samplers
- `LAUNCH` options: used to specify options for launching the sweep on the cluster
- `COLLECT` options: used to specify options for collecting the sweep results
- `ANALYZE` options: used to specify options for analyzing the sweep results

In addition to these parts, there are a few top-level options that should be set, including:
- `ROOT_DIR`: root directory where all sweep output subdirectories will be placed
- `NAME`: the sweep name must be unique and defines the output subdirectories

For full documentation see: [`sweep/config.py`](../pycls/sweep/config.py). It is easier to get started by looking at the example sweeps at the end of this doc prior to looking at the full documentation.

### Setting up a sweep

[`sweep_setup.py`](../tools/sweep_setup.py): Setting up a sweep generates the individual pycls job configs necessary to launch the sweep. This may take some time (many minutes) if sampling many configs or if it is difficult to find configs that satisfy the sampling constraints. Once the sweep config is defined, the sweep can be set up via:
```
SWEEP_CFG=path/to/config.yaml
./tools/sweep_setup.py --sweep-cfg $SWEEP_CFG
```
In this and following examples we assume the sweep config is stored at `path/to/config.yaml`.

The following files are created in the output directory:
```
ROOT_DIR/NAME/cfgs/??????.yaml  # numbered configs for individual pycls jobs
ROOT_DIR/NAME/cfgs_summary.yaml # summary of the generated cfgs
ROOT_DIR/NAME/sweep_cfg.yaml    # copy of the original sweep configuration
```
Here ROOT_DIR and NAME are the fields specified in the sweep config. Note that before launching the sweep, you should spot-check some of the generated configs and the cfgs_summary.yaml to see if the generated pycls configs look reasonable. You can run the sweep_setup command repeatedly so long as you have not launched the sweep.

### Launching a sweep

[`sweep_launch.py`](../tools/sweep_launch.py): Launching a sweep sends the individual pycls jobs to a SLURM managed cluster using the options in the `LAUNCH` section of the sweep config. The launch is fairly quick, although the individual pycls jobs may of course run for long periods of time. The sweep can be launched via:
```
./tools/sweep_launch.py --sweep-cfg $SWEEP_CFG
```
The following files are created in the output directory:
```
ROOT_DIR/NAME/logs/??????/*  # results of each individual pycls job
ROOT_DIR/NAME/logs/sbatch/*  # SLURM log files for each pycls job
ROOT_DIR/NAME/pycls/*        # copy of pycls code for basic job isolation
```
A sweep should only be launched once. If the sweep is fully stopped, you can resume it by calling the sweep_launch command again.
While the sweep is running, you can monitor individual jobs by looking in the individual pycls log output directories, or by collecting the sweep (described next).

Note that standard SLURM commands can be used to monitor the sweep and individual jobs, cancel a sweep, requeue it, etc. See the [SLURM documentation](https://slurm.schedmd.com/documentation.html) for more information. Common useful SLURM commands include:
```
squeue --me
scontrol requeue JOBID
scancel JOBID
sinfo -o '%f %A %N %m %G' | column -t
```

### Collecting a sweep

[`sweep_collect.py`](../tools/sweep_collect.py): Collecting a sweep gathers core information from each individual pycls job and places it into a single large json file. Note that collecting a sweep is also a great way to see the *status* of a sweep, and the sweep collection can be run an unlimited number of times. The command for this is:
```
./tools/sweep_collect.py --sweep-cfg $SWEEP_CFG
```
The following files are created in the output directory:
```
ROOT_DIR/NAME/sweep.json  # output file with all of the sweep information
```

### Analyzing a sweep

[`sweep_analyze.py`](../tools/sweep_analyze.py): Analyzing a sweep is the final step in the life cycle of a sweep. Note that the analysis can be run as soon as partial results are collected (via sweep_collect.py) and does not require the sweep to be finished or even for any individual pycls jobs to be finished. The analysis depends on the options in the `ANALYZE` part of the sweep config; note that it is typical to reanalyze the data multiple times while altering the `ANALYZE` options. The command for analysis is:
```
./tools/sweep_analyze.py --sweep-cfg $SWEEP_CFG
```
The following files are created in the output directory:
```
ROOT_DIR/NAME/analysis.html  # html file containing the sweep analysis
```
After generating the analysis, the analysis.html file can be viewed in any browser. As it's a fully self-contained html file (with embedded vector images), it can also be easily shared.

## Sweep examples

We provide three example sweeps (along with their output):
- Sweep config: [`cifar_optim`](../configs/sweeps/cifar/cifar_optim.yaml) | Analysis: [cifar_optim_analysis](https://dl.fbaipublicfiles.com/pycls/sweeps/cifar/cifar_optim_analysis.html)
- Sweep config: [`cifar_regnet`](../configs/sweeps/cifar/cifar_regnet.yaml) | Analysis: [cifar_regnet_analysis](https://dl.fbaipublicfiles.com/pycls/sweeps/cifar/cifar_regnet_analysis.html)
- Sweep config: [`cifar_best`](../configs/sweeps/cifar/cifar_best.yaml) | Analysis: [cifar_best_analysis](https://dl.fbaipublicfiles.com/pycls/sweeps/cifar/cifar_best_analysis.html)

We suggest looking at each example config carefully to understand the setup process. We will go through the [`cifar_optim`](../configs/sweeps/cifar/cifar_optim.yaml) config next in more detail. The other two example configs can serve as additional reference points and demonstrate various simple use cases.

The [`cifar_optim`](../configs/sweeps/cifar/cifar_optim.yaml) config starts with `DESC` and `NAME` fields, which are self-explanatory. Next comes the `SETUP` section. In `SETUP`, the `NUM_CONFIGS: 64` field indicates that we will sample 64 individual pycls configs.
Next are the `SAMPLERS`:
```
SAMPLERS:
  OPTIM.BASE_LR:
    TYPE: float_sampler
    RAND_TYPE: log_uniform
    RANGE: [0.25, 5.0]
    QUANTIZE: 1.0e-10
  OPTIM.WEIGHT_DECAY:
    TYPE: float_sampler
    RAND_TYPE: log_uniform
    RANGE: [5.0e-5, 1.0e-3]
    QUANTIZE: 1.0e-10
```
There are two `SAMPLERS`, one for `OPTIM.BASE_LR` and the other for `OPTIM.WEIGHT_DECAY`. Both of these sample floats using a log-uniform distribution with the ranges and quantization as specified (a short sketch of this sampling process is given at the end of this section). This means that for every sampled config, these two corresponding fields will be sampled from these distributions. The `SAMPLERS` are typically a critical aspect of the sweep config and control how individual pycls configs are generated. There is a lot of flexibility in the type of sampler to use and the distribution from which to sample; see the `SAMPLERS` section of [`sweep/config.py`](../pycls/sweep/config.py) for more details. Note that one can sample any parameter that is part of the base pycls config. In addition, one can put `CONSTRAINTS` on the sampled configs; this functionality is not used in this example but is used in the [`cifar_regnet`](../configs/sweeps/cifar/cifar_regnet.yaml) example.

Next, after the samplers comes the `BASE_CFG`. The `BASE_CFG` is simply a standard pycls config. Every sampled config will be the `BASE_CFG` with the values specified by the `SAMPLERS` (like `OPTIM.BASE_LR`) overwritten. In this example, the `BASE_CFG` is simply ResNet-56 with some strong data augmentation. Note, however, that the epoch length (`OPTIM.MAX_EPOCH`) is set to a fairly short 50 epochs to keep each individual job fast. Typically, when generating a sweep, we keep the epochs per model low and focus not on absolute performance but on observed trends.

Next come the `LAUNCH` options. Note that these may need to be customized to different cluster setups (e.g., the `PARTITION` field will likely need to be changed for different clusters). Finally, the `ANALYZE` options control the generated analysis html file. For example, in this case we plot `METRICS: [lr, wd, lr_wd]`, meaning we plot the learning rate, weight decay, and the product of the two. These are simply shortcuts to the corresponding fields in the config; these shortcuts are defined in [`sweep/analysis.py`](../pycls/sweep/analysis.py).

Take a look at the generated [cifar_optim_analysis](https://dl.fbaipublicfiles.com/pycls/sweeps/cifar/cifar_optim_analysis.html). First, there are Error Distribution Functions (EDFs) for the trained models, see [On Network Design Spaces for Visual Recognition](https://arxiv.org/abs/1905.13214). Next there are plots showing error versus learning rate, weight decay, and the product of the two. An interesting observation is that the product of learning rate and weight decay is most predictive of error. Next are plots of error versus various complexity metrics (note, however, that the model is fixed in all the configs so the complexity metrics don't vary). Finally, training and testing curves are shown for the best three models. As discussed, the analysis is fully customizable (see the options in `ANALYZE`).
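To make the sampler semantics concrete, here is a minimal sketch of what a `float_sampler` with `RAND_TYPE: log_uniform` does. This is illustrative only: the `sample_log_uniform` helper below is hypothetical, and the actual samplers live in the sweep code ([`sweep/config.py`](../pycls/sweep/config.py) and related modules):
```
import numpy as np

def sample_log_uniform(low, high, quantize):
    """Sample uniformly in log space, then round to the quantization step."""
    value = np.exp(np.random.uniform(np.log(low), np.log(high)))
    return np.round(value / quantize) * quantize

# Roughly what one sampled config receives for the two OPTIM fields above:
base_lr = sample_log_uniform(0.25, 5.0, 1.0e-10)            # OPTIM.BASE_LR
weight_decay = sample_log_uniform(5.0e-5, 1.0e-3, 1.0e-10)  # OPTIM.WEIGHT_DECAY
```
Sampling in log space is the natural choice for multiplicative hyperparameters like the learning rate, since it spreads samples evenly across orders of magnitude rather than clustering them near the top of the range.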
+ +Finally, for reference, to run the sweep in its entirety, the steps are: +``` +# from within the pycls root directory: +SWEEP_CFG=configs/sweeps/cifar/cifar_optim.yaml +./tools/sweep_setup.py --sweep-cfg $SWEEP_CFG +./tools/sweep_launch.py --sweep-cfg $SWEEP_CFG +./tools/sweep_collect.py --sweep-cfg $SWEEP_CFG +./tools/sweep_analyze.py --sweep-cfg $SWEEP_CFG +``` + +The best way to learn more about sweeps is to set up your own sweeps. The sweep config system is quite powerful and allows for many interesting experiments. diff --git a/pycls/sweep/analysis.py b/pycls/sweep/analysis.py new file mode 100644 index 0000000..e9b2810 --- /dev/null +++ b/pycls/sweep/analysis.py @@ -0,0 +1,158 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +"""Sweep analysis functions.""" + +import json +from functools import reduce +from operator import getitem + +import numpy as np + + +def load_sweep(sweep_file): + """Loads sweep data from a file.""" + with open(sweep_file, "r") as f: + sweep = json.load(f) + sweep = [data for data in sweep if "test_ema_epoch" in data] + augment_sweep(sweep) + return sweep + + +def augment_sweep(sweep): + """Augments sweep data with fields useful for analysis.""" + # Augment data with "aug" field + for data in sweep: + data["aug"] = {} + # Augment with "aug.lr_wd" field = log(lr) + log(wd) = log(lr * wd) + lrs, wds = get_vals(sweep, "lr"), get_vals(sweep, "wd") + for data, lr, wd in zip(sweep, lrs, wds): + data["aug"]["lr_wd"] = lr + wd + # Augment with "aug.done" field + epoch_ind = get_vals(sweep, "test_epoch.epoch_ind") + epoch_max = get_vals(sweep, "test_epoch.epoch_max") + for data, i, m in zip(sweep, epoch_ind, epoch_max): + data["aug"]["done"] = i[-1] / m[-1] + # Augment with "ema_gain" + errors = get_vals(sweep, "test_epoch.min_top1_err") + errors_ema = get_vals(sweep, "test_ema_epoch.min_top1_err") + for data, error, error_ema in zip(sweep, errors, errors_ema): + data["aug"]["ema_gain"] = max(0, min(error) - min(error_ema)) + + +def sort_sweep(sweep, metric, reverse=False): + """Sorts sweep by any metric (including non scalar metrics).""" + keys = get_vals(sweep, metric) + keys = [k if np.isscalar(k) else json.dumps(k, sort_keys=True) for k in keys] + keys, sweep = zip(*sorted(zip(keys, sweep), key=lambda k: k[0], reverse=reverse)) + return sweep, keys + + +def describe_sweep(sweep, reverse=False): + """Generate a string description of sweep.""" + keys = ["error_ema", "error_tst", "done", "log_file", "cfg.DESC"] + formats = ["ema={:.2f}", "err={:.2f}", "done={:.2f}", "{}", "{}"] + vals = [get_vals(sweep, key) for key in keys] + vals[3] = [v.split("/")[-2] for v in vals[3]] + desc = [" | ".join(formats).format(*val) for val in zip(*vals)] + desc = [s for _, s in sorted(zip(vals[0], desc), reverse=reverse)] + return "\n".join(desc) + + +metrics_info = { + # Each metric has the form [compound_key, label, transform] + "error": ["test_ema_epoch.min_top1_err", "", min], + "error_ema": ["test_ema_epoch.min_top1_err", "", min], + "error_tst": ["test_epoch.min_top1_err", "", min], + "done": ["aug.done", "fraction done", None], + "epochs": ["cfg.OPTIM.MAX_EPOCH", "epochs", None], + # Complexity metrics + "flops": ["complexity.flops", "flops (B)", lambda v: v / 1e9], + "params": ["complexity.params", "params (M)", lambda v: v / 1e6], + "acts": ["complexity.acts", "activations (M)", lambda v: v / 1e6], + "memory": 
["train_epoch.mem", "memory (GB)", lambda v: max(v) / 1e3], + "resolution": ["cfg.TRAIN.IM_SIZE", "resolution", None], + "epoch_fw_bw": ["epoch_times.train_fw_bw_time", "epoch fw_bw time (s)", None], + "epoch_time": ["train_epoch.time_epoch", "epoch total time (s)", np.mean], + "batch_size": ["cfg.TRAIN.BATCH_SIZE", "batch size", None], + # Regnet metrics + "regnet_depth": ["cfg.REGNET.DEPTH", "depth", None], + "regnet_w0": ["cfg.REGNET.W0", "w0", None], + "regnet_wa": ["cfg.REGNET.WA", "wa", None], + "regnet_wm": ["cfg.REGNET.WM", "wm", None], + "regnet_gw": ["cfg.REGNET.GROUP_W", "gw", None], + "regnet_bm": ["cfg.REGNET.BOT_MUL", "bm", None], + # Anynet metrics + "anynet_ds": ["cfg.ANYNET.DEPTHS", "ds", None], + "anynet_ws": ["cfg.ANYNET.WIDTHS", "ws", None], + "anynet_gs": ["cfg.ANYNET.GROUP_WS", "gs", None], + "anynet_bs": ["cfg.ANYNET.BOT_MULS", "bs", None], + "anynet_d": ["cfg.ANYNET.DEPTHS", "d", sum], + "anynet_w": ["cfg.ANYNET.WIDTHS", "w", max], + "anynet_g": ["cfg.ANYNET.GROUP_WS", "g", max], + "anynet_b": ["cfg.ANYNET.BOT_MULS", "b", max], + # Effnet metrics + "effnet_ds": ["cfg.EN.DEPTHS", "ds", None], + "effnet_ws": ["cfg.EN.WIDTHS", "ws", None], + "effnet_ss": ["cfg.EN.STRIDES", "ss", None], + "effnet_bs": ["cfg.EN.EXP_RATIOS", "bs", None], + "effnet_d": ["cfg.EN.DEPTHS", "d", sum], + "effnet_w": ["cfg.EN.WIDTHS", "w", max], + # Optimization metrics + "lr": ["cfg.OPTIM.BASE_LR", r"log$_{10}(lr)$", np.log10], + "min_lr": ["cfg.OPTIM.MIN_LR", r"min_lr", None], + "wd": ["cfg.OPTIM.WEIGHT_DECAY", r"log$_{10}(wd)$", np.log10], + "lr_wd": ["aug.lr_wd", r"log$_{10}(lr \cdot wd)$", None], + "bn_wd": ["cfg.BN.CUSTOM_WEIGHT_DECAY", r"log$_{10}$(bn_wd)", np.log10], + "momentum": ["cfg.OPTIM.MOMENTUM", "", None], + "ema_alpha": ["cfg.OPTIM.EMA_ALPHA", r"log$_{10}$(ema_alpha)", np.log10], + "ema_beta": ["cfg.OPTIM.EMA_BETA", r"log$_{10}$(ema_beta)", np.log10], + "ema_update": ["cfg.OPTIM.EMA_UPDATE_PERIOD", r"log$_{10}$(ema_update)", np.log2], +} + + +def get_info(metric): + """Returns [compound_key, label, transform] for metric.""" + info = metrics_info[metric] if metric in metrics_info else [metric, metric, None] + info[1] = info[1] if info[1] else metric + return info + + +def get_vals(sweep, metric): + """Gets values for given metric (transformed if metric transform is specified).""" + compound_key, _, transform = get_info(metric) + metric_keys = compound_key.split(".") + vals = [reduce(getitem, metric_keys, data) for data in sweep] + vals = [transform(v) for v in vals] if transform else vals + return vals + + +def get_filters(sweep, metrics, alpha=5, sample=0.25, b=2500): + """Use empirical bootstrap to estimate filter ranges per metric for good errors.""" + assert len(sweep), "Sweep cannot be empty." 
    errs = np.array(get_vals(sweep, "error"))
    n, b, filters = len(errs), int(b), {}
    percentiles = [alpha / 2, 50, 100 - alpha / 2]
    n_sample = int(sample) if sample > 1 else max(1, int(n * sample))
    samples = [np.random.choice(n, n_sample) for _ in range(b)]
    samples = [s[np.argmin(errs[s])] for s in samples]
    for metric in metrics:
        vals = np.array(get_vals(sweep, metric))
        vals = [vals[s] for s in samples]
        v_min, v_med, v_max = tuple(np.percentile(vals, percentiles))
        filters[metric] = [v_min, v_med, v_max]
    return filters


def apply_filters(sweep, filters):
    """Filter sweep according to dict of filters of form {metric: [min, med, max]}."""
    filters = filters if filters else {}
    for metric, (v_min, _, v_max) in filters.items():
        keep = [v_min <= v <= v_max for v in get_vals(sweep, metric)]
        sweep = [data for k, data in zip(keep, sweep) if k]
    return sweep
diff --git a/pycls/sweep/htmlbook.py b/pycls/sweep/htmlbook.py
new file mode 100644
index 0000000..0555c28
--- /dev/null
+++ b/pycls/sweep/htmlbook.py
@@ -0,0 +1,73 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Htmlbook - Piotr's lightweight alternative to notebooks."""

import base64
from io import BytesIO

import matplotlib.pyplot as plt
from yattag import Doc, indent


class Htmlbook:
    """An Htmlbook is used to generate an html page from text and matplotlib figures."""

    def __init__(self, title):
        """Initializes Htmlbook with a given title."""
        # The doc is used for the body of the document
        self.doc, self.tag, self.text, self.line = Doc().ttl()
        # The top_doc is used for the title and table of contents
        self.top_doc, self.top_tag, self.top_text, self.top_line = Doc().ttl()
        # Add link anchor and title to the top_doc
        self.top_line("a", "", name="top")
        self.top_line("h1", title)
        self.section_counter = 1

    def add_section(self, name):
        """Adds a section to the Htmlbook (also updates table of contents)."""
        anchor = "section{:03d}".format(self.section_counter)
        name = str(self.section_counter) + " " + name
        anchor_style = "text-decoration: none;"
        self.section_counter += 1
        # Add section to main text
        self.doc.stag("br")
        self.doc.stag("hr", style="border: 2px solid")
        with self.tag("h3"):
            self.line("a", "", name=anchor)
            self.text(name + " ")
            self.line("a", "[top]", href="#top", style=anchor_style)
        # Add section to table of contents
        self.top_line("a", name, href="#" + anchor, style=anchor_style)
        self.top_doc.stag("br")

    def add_plot(self, matplotlib_figure, ext="svg", **kwargs):
        """Adds a matplotlib figure embedded directly into the html."""
        out = BytesIO()
        matplotlib_figure.savefig(out, format=ext, bbox_inches="tight", **kwargs)
        plt.close(matplotlib_figure)
        if ext == "svg":
            # Embed the svg markup inline so the html stays fully self-contained
            self.doc.asis(out.getvalue().decode("utf-8"))
        else:
            # For other formats, embed the image as a base64-encoded data URI
            data = base64.b64encode(out.getvalue()).decode("utf-8")
            self.doc.asis("<img src='data:image/{};base64,{}'>".format(ext, data))
        self.doc.stag("br")

    def add_details(self, summary, details):
        """Adds a collapsible details section to Htmlbook."""
        with self.tag("details"):
            self.line("summary", summary)
            self.line("pre", details)

    def to_text(self):
        """Generates a string representing the Htmlbook (including figures)."""
        return indent(self.top_doc.getvalue() + self.doc.getvalue())

    def to_file(self, out_file):
        """Saves Htmlbook to a file (typically should have .html extension)."""
        with open(out_file, "w") as file:
            file.write(self.to_text())
diff --git a/pycls/sweep/plotting.py
b/pycls/sweep/plotting.py new file mode 100644 index 0000000..fbc847b --- /dev/null +++ b/pycls/sweep/plotting.py @@ -0,0 +1,287 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +"""Sweep plotting functions.""" + +import matplotlib.lines as lines +import matplotlib.pyplot as plt +import matplotlib.ticker as ticker +import numpy as np +import pycls.models.regnet as regnet +from pycls.sweep.analysis import get_info, get_vals, sort_sweep + + +# Global color scheme and fill color +_COLORS, _COLOR_FILL = [], [] + + +def set_plot_style(): + """Sets default plotting styles for all plots.""" + plt.rcParams["figure.figsize"] = [3.0, 2] + plt.rcParams["axes.linewidth"] = 1 + plt.rcParams["axes.grid"] = True + plt.rcParams["grid.alpha"] = 0.4 + plt.rcParams["xtick.bottom"] = False + plt.rcParams["ytick.left"] = False + plt.rcParams["legend.edgecolor"] = "0.3" + plt.rcParams["axes.xmargin"] = 0.025 + plt.rcParams["lines.linewidth"] = 1.25 + plt.rcParams["lines.markersize"] = 5.0 + plt.rcParams["font.size"] = 10 + plt.rcParams["axes.titlesize"] = 10 + plt.rcParams["legend.fontsize"] = 8 + plt.rcParams["legend.title_fontsize"] = 8 + plt.rcParams["xtick.labelsize"] = 7 + plt.rcParams["ytick.labelsize"] = 7 + + +def set_colors(colors=None): + """Sets the global color scheme (colors should be a list of rgb float values).""" + global _COLORS + default_colors = [ + [0.000, 0.447, 0.741], + [0.850, 0.325, 0.098], + [0.929, 0.694, 0.125], + [0.494, 0.184, 0.556], + [0.466, 0.674, 0.188], + [0.301, 0.745, 0.933], + [0.635, 0.078, 0.184], + [0.300, 0.300, 0.300], + [0.600, 0.600, 0.600], + [1.000, 0.000, 0.000], + ] + colors = default_colors if colors is None else colors + colors, n = np.array(colors), len(colors) + err_str = "Invalid colors list: {}".format(colors) + assert ((colors >= 0) & (colors <= 1)).all() and colors.shape[1] == 3, err_str + _COLORS = np.tile(colors, (int(np.ceil((10000 / n))), 1)).reshape((-1, 3)) + + +def set_color_fill(color_fill=None): + """Sets the global color fill (color should be a set of rgb float values).""" + global _COLOR_FILL + _COLOR_FILL = [0.000, 0.447, 0.741] if color_fill is None else color_fill + + +def get_color(ind=(), scale=1, dtype=float): + """Gets color (or colors) referenced by index (or indices).""" + return np.ndarray.astype(_COLORS[ind] * scale, dtype) + + +def fig_make(m_rows, n_cols, flatten, **kwargs): + """Gets figure for plotting with m x n axes.""" + figsize = plt.rcParams["figure.figsize"] + figsize = (figsize[0] * n_cols, figsize[1] * m_rows) + fig, axes = plt.subplots(m_rows, n_cols, figsize=figsize, squeeze=False, **kwargs) + axes = [ax for axes in axes for ax in axes] if flatten else axes + return fig, axes + + +def fig_legend(fig, n_cols, names, colors=None, styles=None, markers=None): + """Adds legend to figure and tweaks layout (call after fig is done).""" + n, c, s, m = len(names), colors, styles, markers + c = c if c else get_color()[:n] + s = [""] * n if s is None else [s] * n if type(s) == str else s + m = ["o"] * n if m is None else [m] * n if type(m) == str else m + n_cols = int(np.ceil(n / np.ceil(n / n_cols))) + hs = [lines.Line2D([0], [0], color=c, ls=s, marker=m) for c, s, m in zip(c, s, m)] + fig.legend(hs, names, bbox_to_anchor=(0.5, 1.0), loc="lower center", ncol=n_cols) + fig.tight_layout(pad=0.3, h_pad=1.08, w_pad=1.08) + + +def plot_edf(sweeps, names): + """Plots error 
EDF for each sweep.""" + m, n = 1, 1 + fig, axes = fig_make(m, n, True) + for i, sweep in enumerate(sweeps): + k = len(sweep) + errs = sorted(get_vals(sweep, "error")) + edf = np.cumsum(np.ones(k) / k) + label = "{:3d}|{:.1f}|{:.1f}".format(k, min(errs), np.mean(errs)) + axes[0].plot(errs, edf, "-", alpha=0.8, c=get_color(i), label=label) + axes[0].legend(loc="lower right", title=" " * 10 + "n|min|mean") + axes[0].set_xlabel("error") + axes[0].set_ylabel("cumulative prob.") + fig_legend(fig, n, names, styles="-", markers="") + return fig + + +def plot_values(sweeps, names, metrics, filters): + """Plots scatter plot of error versus metric for each metric and sweep.""" + m, n, c = len(metrics), len(sweeps), _COLOR_FILL + fig, axes = fig_make(m, n, False, sharex="row", sharey=True) + e_min = min(min(get_vals(sweep, "error")) for sweep in sweeps) + e_max = max(max(get_vals(sweep, "error")) for sweep in sweeps) + e_min, e_max = e_min - (e_max - e_min) / 5, e_max + (e_max - e_min) / 5 + for i, j in [(i, j) for i in range(m) for j in range(n)]: + metric, sweep, ax, f = metrics[i], sweeps[j], axes[i][j], filters[j] + errs = get_vals(sweep, "error") + vals = get_vals(sweep, metric) + v_min, v_med, v_max = f[metric] + f = [float(str("{:.3e}".format(f))) for f in f[metric]] + l_rng, l_med = "[{}, {}]".format(f[0], f[2]), "best: {}".format(f[1]) + ax.scatter(vals, errs, color=get_color(j), alpha=0.8) + ax.plot([v_med, v_med], [e_min, e_max], c="k", label=l_med) + ax.fill_between([v_min, v_max], e_min, e_max, alpha=0.1, color=c, label=l_rng) + ax.legend(loc="upper left") + ax.set_ylabel("error" if j == 0 else "") + ax.set_xlabel(get_info(metric)[1]) + fig_legend(fig, n, names) + return fig + + +def plot_values_2d(sweeps, names, metric_pairs): + """Plots color-coded scatter plot for each metric_pair and sweep.""" + m, n = len(metric_pairs), len(sweeps) + fig, axes = fig_make(m, n, False, sharex="row", sharey="row") + for i, j in [(i, j) for i in range(m) for j in range(n)]: + sweep, ax = sweeps[j], axes[i][j] + metric_x, metric_y = metric_pairs[i] + xs = get_vals(sweep, metric_x) + ys = get_vals(sweep, metric_y) + errs = get_vals(sweep, "error") + ranks = (np.argsort(np.argsort(errs)) + 1) / len(errs) + s = ax.scatter(xs, ys, c=ranks, alpha=0.6) + ax.set_xlabel(get_info(metric_x)[1], fontsize=12) + ax.set_ylabel(get_info(metric_y)[1], fontsize=12) + fig.colorbar(s, ax=ax) if j == n - 1 else () + fig_legend(fig, n, names) + return fig + + +def plot_trends(sweeps, names, metrics, filters, max_cols=0): + """Plots metric versus sweep for each metric.""" + n_metrics, xs = len(metrics), range(len(sweeps)) + max_cols = max_cols if max_cols else len(sweeps) + m = int(np.ceil(n_metrics / max_cols)) + n = min(max_cols, int(np.ceil(n_metrics / m))) + fig, axes = fig_make(m, n, True, sharex=False, sharey=False) + [ax.axis("off") for ax in axes[n_metrics::]] + for ax, metric in zip(axes, metrics): + # Get values to plot + vals = [get_vals(sweep, metric) for sweep in sweeps] + vs_min, vs_max = [min(v) for v in vals], [max(v) for v in vals] + fs_min, fs_med, fs_max = zip(*[f[metric] for f in filters]) + # Show full range + ax.plot(xs, vs_min, "-", xs, vs_max, "-", c="0.7") + ax.fill_between(xs, vs_min, vs_max, alpha=0.05, color=_COLOR_FILL) + # Show good range + ax.plot(xs, fs_min, "-", xs, fs_max, "-", c="0.5") + ax.fill_between(xs, fs_min, fs_max, alpha=0.10, color=_COLOR_FILL) + # Show best range + ax.plot(xs, fs_med, "-o", c="k") + # Show good range with markers + ax.scatter(xs, fs_min, c=get_color(xs), 
marker="^", s=80, zorder=10) + ax.scatter(xs, fs_max, c=get_color(xs), marker="v", s=80, zorder=10) + # Finalize axis + ax.set_ylabel(get_info(metric)[1]) + ax.set_xticks([]) + ax.set_xlabel("sweep") + fig_legend(fig, n, names, markers="D") + return fig + + +def plot_curves(sweeps, names, metric, n_curves, reverse=False): + """Plots metric versus epoch for up to best n_curves jobs per sweep.""" + ms = [min(n_curves, len(sweep)) for sweep in sweeps] + m, n = max(ms), len(sweeps) + fig, axes = fig_make(m, n, False, sharex=False, sharey=True) + sweeps = [sort_sweep(sweep, "error", reverse)[0] for sweep in sweeps] + xs_trn = [get_vals(sweep, "train_epoch.epoch_ind") for sweep in sweeps] + xs_tst = [get_vals(sweep, "test_epoch.epoch_ind") for sweep in sweeps] + xs_ema = [get_vals(sweep, "test_ema_epoch.epoch_ind") for sweep in sweeps] + xs_max = [get_vals(sweep, "test_ema_epoch.epoch_max") for sweep in sweeps] + ys_trn = [get_vals(sweep, "train_epoch." + metric) for sweep in sweeps] + ys_tst = [get_vals(sweep, "test_epoch." + metric) for sweep in sweeps] + ys_ema = [get_vals(sweep, "test_ema_epoch." + metric) for sweep in sweeps] + ticks = [1, 2, 4, 8, 16, 32, 64, 100] + y_min = min(min(y) for y in ys_ema + ys_tst for y in y) + y_min = ticks[np.argmin(np.asarray(ticks) <= y_min) - 1] + for i, j in [(i, j) for j in range(n) for i in range(ms[j])]: + ax, x_max = axes[i][j], xs_max[j][i][-1] + x_trn, y_trn, e_trn = xs_trn[j][i], ys_trn[j][i], min(ys_trn[j][i]) + x_tst, y_tst, e_tst = xs_tst[j][i], ys_tst[j][i], min(ys_tst[j][i]) + x_ema, y_ema, e_ema = xs_ema[j][i], ys_ema[j][i], min(ys_ema[j][i]) + label, prop = "{} {:5.2f}", {"color": get_color(j), "alpha": 0.8} + ax.plot(x_trn, y_trn, "--", **prop, label=label.format("trn", e_trn)) + ax.plot(x_tst, y_tst, ":", **prop, label=label.format("tst", e_tst)) + ax.plot(x_ema, y_ema, "-", **prop, label=label.format("ema", e_ema)) + ax.plot([x_ema[0], x_ema[-1]], [e_ema, e_ema], "-", color="k", alpha=0.8) + xy_good = [(x, y) for x, y in zip(x_ema, y_ema) if y < 1.01 * e_ema] + ax.scatter(*zip(*xy_good), **prop, s=10) + ax.scatter([np.argmin(y_ema) + 1], e_ema, **prop) + ax.legend(loc="upper right") + ax.set_xlim(right=x_max) + for i, j in [(i, j) for i in range(m) for j in range(n)]: + ax = axes[i][j] + ax.set_xlabel("epoch" if i == m - 1 else "") + ax.set_ylabel(metric if j == 0 else "") + ax.set_yscale("log", base=2) + ax.set_yticks(ticks) + ax.set_yticklabels(ticks) + ax.set_yticks([t * np.sqrt(2) for t in ticks], minor=True) + ax.set_yticklabels([], minor=True) + ax.set_ylim(bottom=y_min, top=100) + ax.yaxis.grid(True, which="minor") + fig_legend(fig, n, names, styles="-", markers="") + return fig + + +def plot_models(sweeps, names, n_models, reverse=False): + """Plots model visualization for up to n_models per sweep.""" + ms = [min(n_models, len(sweep)) for sweep in sweeps] + m, n = max(ms), len(sweeps) + fig, axes = fig_make(m, n, False, sharex=True, sharey=True) + sweeps = [sort_sweep(sweep, "error", reverse)[0] for sweep in sweeps] + for i, j in [(i, j) for j in range(n) for i in range(ms[j])]: + ax, sweep, color = axes[i][j], [sweeps[j][i]], get_color(j) + metrics = ["error", "flops", "params", "acts", "epoch_fw_bw", "resolution"] + vals = [get_vals(sweep, m)[0] for m in metrics] + label = "e = {:.2f}%, f = {:.2f}B\n".format(*vals[0:2]) + label += "p = {:.2f}M, a = {:.2f}M\n".format(*vals[2:4]) + label += "t = {0:.0f}s, r = ${1:d} \\times {1:d}$\n".format(*vals[4:6]) + model_type = get_vals(sweep, "cfg.MODEL.TYPE")[0] + if model_type == 
"regnet": + metrics = ["GROUP_W", "BOT_MUL", "WA", "W0", "WM", "DEPTH"] + vals = [get_vals(sweep, "cfg.REGNET." + m)[0] for m in metrics] + ws, ds, _, _, _, ws_cont = regnet.generate_regnet(*vals[2:]) + label += "$d_i = {:s}$\n$w_i = {:s}$\n".format(str(ds), str(ws)) + label += "$g={:d}$, $b={:g}$, $w_a={:.1f}$\n".format(*vals[:3]) + label += "$w_0={:d}$, $w_m={:.3f}$".format(*vals[3:5]) + ax.plot(ws_cont, ":", c=color) + elif model_type == "anynet": + metrics = ["anynet_ds", "anynet_ws", "anynet_gs", "anynet_bs"] + ds, ws, gs, bs = [get_vals(sweep, m)[0] for m in metrics] + label += "$d_i = {:s}$\n$w_i = {:s}$\n".format(str(ds), str(ws)) + label += "$g_i = {:s}$\n$b_i = {:s}$".format(str(gs), str(bs)) + elif model_type == "effnet": + metrics = ["effnet_ds", "effnet_ws", "effnet_ss", "effnet_bs"] + ds, ws, ss, bs = [get_vals(sweep, m)[0] for m in metrics] + label += "$d_i = {:s}$\n$w_i = {:s}$\n".format(str(ds), str(ws)) + label += "$s_i = {:s}$\n$b_i = {:s}$".format(str(ss), str(bs)) + else: + raise AssertionError("Unknown model type" + model_type) + ws_all = [w for ws in [[w] * d for d, w in zip(ds, ws)] for w in ws] + ds_cum = np.cumsum([0] + ds[0:-1]) + ax.plot(ws_all, "o-", c=color, markersize=plt.rcParams["lines.markersize"] - 1) + ax.plot(ds_cum, ws, "o", c="k", fillstyle="none", label=label) + ax.legend(loc="lower right", markerscale=0, handletextpad=0, handlelength=0) + for i, j in [(i, j) for i in range(m) for j in range(n)]: + ax = axes[i][j] + ax.set_xlabel("block index" if i == m - 1 else "") + ax.set_ylabel("width" if j == 0 else "") + ax.set_yscale("log", base=2) + ax.yaxis.set_major_formatter(ticker.ScalarFormatter()) + ax.xaxis.set_major_locator(ticker.MaxNLocator(integer=True)) + fig_legend(fig, n, names, styles="-") + return fig + + +# Set global plot style and colors on import +set_plot_style() +set_colors() +set_color_fill() diff --git a/requirements.txt b/requirements.txt index c72334c..2c1bd6e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,6 +2,7 @@ black==19.3b0 isort==4.3.21 iopath flake8 +pyyaml matplotlib numpy opencv-python==4.2.0.34 @@ -9,3 +10,4 @@ parameterized setuptools simplejson yacs +yattag diff --git a/tools/sweep_analyze.py b/tools/sweep_analyze.py new file mode 100644 index 0000000..f4a6fec --- /dev/null +++ b/tools/sweep_analyze.py @@ -0,0 +1,116 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +"""Analyze results of a sweep.""" + +import os +import time + +import matplotlib.pyplot as plt +import pycls.sweep.analysis as analysis +import pycls.sweep.config as sweep_config +import pycls.sweep.plotting as plotting +from pycls.sweep.config import sweep_cfg +from pycls.sweep.htmlbook import Htmlbook + + +def sweep_analyze(): + """Analyzes results of a sweep.""" + start_time = time.time() + analyze_cfg = sweep_cfg.ANALYZE + sweep_dir = os.path.join(sweep_cfg.ROOT_DIR, sweep_cfg.NAME) + print("Generating sweepbook for {:s}... 
".format(sweep_dir), end="", flush=True) + # Initialize Htmlbook for results + h = Htmlbook(sweep_cfg.NAME) + # Output sweep config + h.add_section("Config") + with open(sweep_cfg.SWEEP_CFG_FILE, "r") as f: + sweep_cfg_raw = f.read() + h.add_details("sweep_cfg", sweep_cfg_raw) + h.add_details("sweep_cfg_full", str(sweep_cfg)) + # Load sweep and plot EDF + names = [sweep_cfg.NAME] + analyze_cfg.EXTRA_SWEEP_NAMES + files = [os.path.join(sweep_cfg.ROOT_DIR, name, "sweep.json") for name in names] + sweeps = [analysis.load_sweep(file) for file in files] + names = [os.path.basename(name) for name in names] + assert all(len(sweep) for sweep in sweeps), "Loaded sweep cannot be empty." + h.add_section("EDF") + h.add_plot(plotting.plot_edf(sweeps, names)) + for sweep, name in zip(sweeps, names): + h.add_details(name, analysis.describe_sweep(sweep)) + # Pre filter sweep according to pre_filters and plot EDF + pre_filters = analyze_cfg.PRE_FILTERS + if pre_filters: + sweeps = [analysis.apply_filters(sweep, pre_filters) for sweep in sweeps] + assert all(len(sweep) for sweep in sweeps), "Filtered sweep cannot be empty." + h.add_section("EDF Filtered") + h.add_plot(plotting.plot_edf(sweeps, names)) + for sweep, name in zip(sweeps, names): + h.add_details(name, analysis.describe_sweep(sweep)) + # Split sweep according to split_filters and plot EDF + split_filters = analyze_cfg.SPLIT_FILTERS + if split_filters and len(names) == 1: + names = list(split_filters.keys()) + sweeps = [analysis.apply_filters(sweeps[0], f) for f in split_filters.values()] + assert all(len(sweep) for sweep in sweeps), "Split sweep cannot be empty." + h.add_section("EDF Split") + h.add_plot(plotting.plot_edf(sweeps, names)) + for sweep, name in zip(sweeps, names): + h.add_details(name, analysis.describe_sweep(sweep)) + # Plot metric scatter plots + metrics = analyze_cfg.METRICS + plot_metric_trends = analyze_cfg.PLOT_METRIC_TRENDS and len(sweeps) > 1 + if metrics and (analyze_cfg.PLOT_METRIC_VALUES or plot_metric_trends): + h.add_section("Metrics") + filters = [analysis.get_filters(sweep, metrics) for sweep in sweeps] + if analyze_cfg.PLOT_METRIC_VALUES: + h.add_plot(plotting.plot_values(sweeps, names, metrics, filters)) + if plot_metric_trends: + h.add_plot(plotting.plot_trends(sweeps, names, metrics, filters)) + # Plot complexity scatter plots + complexity = analyze_cfg.COMPLEXITY + plot_complexity_trends = analyze_cfg.PLOT_COMPLEXITY_TRENDS and len(sweeps) > 1 + if complexity and (analyze_cfg.PLOT_COMPLEXITY_VALUES or plot_complexity_trends): + h.add_section("Complexity") + filters = [analysis.get_filters(sweep, complexity) for sweep in sweeps] + if analyze_cfg.PLOT_COMPLEXITY_VALUES: + h.add_plot(plotting.plot_values(sweeps, names, complexity, filters)) + if plot_complexity_trends: + h.add_plot(plotting.plot_trends(sweeps, names, complexity, filters)) + # Plot best/worst error curves + n = analyze_cfg.PLOT_CURVES_BEST + if n > 0: + h.add_section("Best Errors") + h.add_plot(plotting.plot_curves(sweeps, names, "top1_err", n, False)) + n = analyze_cfg.PLOT_CURVES_WORST + if n > 0: + h.add_section("Worst Errors") + h.add_plot(plotting.plot_curves(sweeps, names, "top1_err", n, True)) + # Plot best/worst models + n = analyze_cfg.PLOT_MODELS_BEST + if n > 0: + h.add_section("Best Models") + h.add_plot(plotting.plot_models(sweeps, names, n, False)) + n = analyze_cfg.PLOT_MODELS_WORST + if n > 0: + h.add_section("Worst Models") + h.add_plot(plotting.plot_models(sweeps, names, n, True)) + # Output Htmlbook and finalize analysis + 
h.to_file(os.path.join(sweep_dir, "analysis.html")) + plt.close("all") + print("Done [t={:.1f}s]".format(time.time() - start_time)) + + +def main(): + desc = "Analyze results of a sweep." + sweep_config.load_cfg_fom_args(desc) + sweep_cfg.freeze() + sweep_analyze() + + +if __name__ == "__main__": + main()
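# Example usage (from the pycls root directory; see docs/SWEEP_INFO.md):
#   ./tools/sweep_analyze.py --sweep-cfg configs/sweeps/cifar/cifar_optim.yaml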