From 250eb178dc5479fae73f016e6a663eade52baf59 Mon Sep 17 00:00:00 2001 From: lucasb-eyer Date: Mon, 27 Nov 2017 12:18:19 +0100 Subject: [PATCH] Use python logging in training. This way, we get the training logs in the experiment_root too! TODO: Maybe also do that in embed and eval? --- common.py | 199 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ train.py | 43 +++++++----- 2 files changed, 225 insertions(+), 17 deletions(-) diff --git a/common.py b/common.py index 7743d37..5ef5722 100644 --- a/common.py +++ b/common.py @@ -1,6 +1,7 @@ """ A bunch of general utilities shared by train/embed/eval """ from argparse import ArgumentTypeError +import logging import os import numpy as np @@ -154,3 +155,201 @@ def fid_to_image(fid, pid, image_root, image_size): image_resized = tf.image.resize_images(image_decoded, image_size) return image_resized, fid, pid + + +def get_logging_dict(name): + return { + 'version': 1, + 'disable_existing_loggers': False, + 'formatters': { + 'standard': { + 'format': '%(asctime)s [%(levelname)s] %(name)s: %(message)s' + }, + }, + 'handlers': { + 'stderr': { + 'level': 'INFO', + 'formatter': 'standard', + 'class': 'common.ColorStreamHandler', + 'stream': 'ext://sys.stderr', + }, + 'logfile': { + 'level': 'DEBUG', + 'formatter': 'standard', + 'class': 'logging.FileHandler', + 'filename': name + '.log', + 'mode': 'a', + } + }, + 'loggers': { + '': { + 'handlers': ['stderr', 'logfile'], + 'level': 'DEBUG', + 'propagate': True + }, + + # extra ones to shut up. + 'tensorflow': { + 'handlers': ['stderr', 'logfile'], + 'level': 'INFO', + }, + } + } + + +# Source for the remainder: https://gist.github.com/mooware/a1ed40987b6cc9ab9c65 +# Fixed some things mentioned in the comments there. + +# colored stream handler for python logging framework (use the ColorStreamHandler class). +# +# based on: +# http://stackoverflow.com/questions/384076/how-can-i-color-python-logging-output/1336640#1336640 + +# how to use: +# i used a dict-based logging configuration, not sure what else would work. +# +# import logging, logging.config, colorstreamhandler +# +# _LOGCONFIG = { +# "version": 1, +# "disable_existing_loggers": False, +# +# "handlers": { +# "console": { +# "class": "colorstreamhandler.ColorStreamHandler", +# "stream": "ext://sys.stderr", +# "level": "INFO" +# } +# }, +# +# "root": { +# "level": "INFO", +# "handlers": ["console"] +# } +# } +# +# logging.config.dictConfig(_LOGCONFIG) +# mylogger = logging.getLogger("mylogger") +# mylogger.warning("foobar") + +# Copyright (c) 2014 Markus Pointner +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + +class _AnsiColorStreamHandler(logging.StreamHandler): + DEFAULT = '\x1b[0m' + RED = '\x1b[31m' + GREEN = '\x1b[32m' + YELLOW = '\x1b[33m' + CYAN = '\x1b[36m' + + CRITICAL = RED + ERROR = RED + WARNING = YELLOW + INFO = DEFAULT # GREEN + DEBUG = CYAN + + @classmethod + def _get_color(cls, level): + if level >= logging.CRITICAL: return cls.CRITICAL + elif level >= logging.ERROR: return cls.ERROR + elif level >= logging.WARNING: return cls.WARNING + elif level >= logging.INFO: return cls.INFO + elif level >= logging.DEBUG: return cls.DEBUG + else: return cls.DEFAULT + + def __init__(self, stream=None): + logging.StreamHandler.__init__(self, stream) + + def format(self, record): + text = logging.StreamHandler.format(self, record) + color = self._get_color(record.levelno) + return (color + text + self.DEFAULT) if self.is_tty() else text + + def is_tty(self): + isatty = getattr(self.stream, 'isatty', None) + return isatty and isatty() + + +class _WinColorStreamHandler(logging.StreamHandler): + # wincon.h + FOREGROUND_BLACK = 0x0000 + FOREGROUND_BLUE = 0x0001 + FOREGROUND_GREEN = 0x0002 + FOREGROUND_CYAN = 0x0003 + FOREGROUND_RED = 0x0004 + FOREGROUND_MAGENTA = 0x0005 + FOREGROUND_YELLOW = 0x0006 + FOREGROUND_GREY = 0x0007 + FOREGROUND_INTENSITY = 0x0008 # foreground color is intensified. + FOREGROUND_WHITE = FOREGROUND_BLUE | FOREGROUND_GREEN | FOREGROUND_RED + + BACKGROUND_BLACK = 0x0000 + BACKGROUND_BLUE = 0x0010 + BACKGROUND_GREEN = 0x0020 + BACKGROUND_CYAN = 0x0030 + BACKGROUND_RED = 0x0040 + BACKGROUND_MAGENTA = 0x0050 + BACKGROUND_YELLOW = 0x0060 + BACKGROUND_GREY = 0x0070 + BACKGROUND_INTENSITY = 0x0080 # background color is intensified. + + DEFAULT = FOREGROUND_WHITE + CRITICAL = BACKGROUND_YELLOW | FOREGROUND_RED | FOREGROUND_INTENSITY | BACKGROUND_INTENSITY + ERROR = FOREGROUND_RED | FOREGROUND_INTENSITY + WARNING = FOREGROUND_YELLOW | FOREGROUND_INTENSITY + INFO = FOREGROUND_GREEN + DEBUG = FOREGROUND_CYAN + + @classmethod + def _get_color(cls, level): + if level >= logging.CRITICAL: return cls.CRITICAL + elif level >= logging.ERROR: return cls.ERROR + elif level >= logging.WARNING: return cls.WARNING + elif level >= logging.INFO: return cls.INFO + elif level >= logging.DEBUG: return cls.DEBUG + else: return cls.DEFAULT + + def _set_color(self, code): + import ctypes + ctypes.windll.kernel32.SetConsoleTextAttribute(self._outhdl, code) + + def __init__(self, stream=None): + logging.StreamHandler.__init__(self, stream) + # get file handle for the stream + import ctypes, ctypes.util + # for some reason find_msvcrt() sometimes doesn't find msvcrt.dll on my system? + crtname = ctypes.util.find_msvcrt() + if not crtname: + crtname = ctypes.util.find_library("msvcrt") + crtlib = ctypes.cdll.LoadLibrary(crtname) + self._outhdl = crtlib._get_osfhandle(self.stream.fileno()) + + def emit(self, record): + color = self._get_color(record.levelno) + self._set_color(color) + logging.StreamHandler.emit(self, record) + self._set_color(self.FOREGROUND_WHITE) + +# select ColorStreamHandler based on platform +import platform +if platform.system() == 'Windows': + ColorStreamHandler = _WinColorStreamHandler +else: + ColorStreamHandler = _AnsiColorStreamHandler diff --git a/train.py b/train.py index 11001e7..9294494 100755 --- a/train.py +++ b/train.py @@ -2,6 +2,7 @@ from argparse import ArgumentParser from datetime import timedelta from importlib import import_module +import logging.config import os from signal import SIGINT, SIGTERM import sys @@ -195,9 +196,9 @@ def main(): # If the experiment directory exists already, we bail in fear. if os.path.exists(args.experiment_root): if os.listdir(args.experiment_root): - print('The directory {} already exists and is not empty. If ' - 'you want to resume training, append --resume to your ' - 'call.'.format(args.experiment_root)) + print('The directory {} already exists and is not empty.' + ' If you want to resume training, append --resume to' + ' your call.'.format(args.experiment_root)) exit(1) else: os.makedirs(args.experiment_root) @@ -207,19 +208,23 @@ def main(): with open(args_file, 'w') as f: json.dump(vars(args), f, ensure_ascii=False, indent=2, sort_keys=True) + log_file = os.path.join(args.experiment_root, "train") + logging.config.dictConfig(common.get_logging_dict(log_file)) + log = logging.getLogger('train') + # Also show all parameter values at the start, for ease of reading logs. - print('Training using the following parameters:') + log.info('Training using the following parameters:') for key, value in sorted(vars(args).items()): - print('{}: {}'.format(key, value)) + log.info('{}: {}'.format(key, value)) # Check them here, so they are not required when --resume-ing. if not args.train_set: parser.print_help() - print("You did not specify the `train_set` argument!") + log.error("You did not specify the `train_set` argument!") sys.exit(1) if not args.image_root: parser.print_help() - print("You did not specify the required `image_root` argument!") + log.error("You did not specify the required `image_root` argument!") sys.exit(1) # Load the data from the CSV file. @@ -351,7 +356,7 @@ def main(): if args.resume: # In case we're resuming, simply load the full checkpoint to init. last_checkpoint = tf.train.latest_checkpoint(args.experiment_root) - print('Restoring from checkpoint: {}'.format(last_checkpoint)) + log.info('Restoring from checkpoint: {}'.format(last_checkpoint)) checkpoint_saver.restore(sess, last_checkpoint) else: # But if we're starting from scratch, we may need to load some @@ -370,7 +375,7 @@ def main(): summary_writer = tf.summary.FileWriter(args.experiment_root, sess.graph) start_step = sess.run(global_step) - print('Starting training from iteration {}.'.format(start_step)) + log.info('Starting training from iteration {}.'.format(start_step)) # Finally, here comes the main-loop. This `Uninterrupt` is a handy # utility such that an iteration still finishes on Ctrl+C and we can @@ -397,13 +402,17 @@ def main(): # Do a huge print out of the current progress. seconds_todo = (args.train_iterations - step) * elapsed_time - print('iter:{:6d}, loss min|avg|max: {:.3f}|{:.3f}|{:6.3f}, ' - 'batch-p@{}: {:.2%}, ETA: {} ({:.2f}s/it)'.format( - step, float(np.min(b_loss)), float(np.mean(b_loss)), - float(np.max(b_loss)), - args.batch_k-1, float(b_prec_at_k), - timedelta(seconds=int(seconds_todo)), elapsed_time), - flush=True) + log.info('iter:{:6d}, loss min|avg|max: {:.3f}|{:.3f}|{:6.3f}, ' + 'batch-p@{}: {:.2%}, ETA: {} ({:.2f}s/it)'.format( + step, + float(np.min(b_loss)), + float(np.mean(b_loss)), + float(np.max(b_loss)), + args.batch_k-1, float(b_prec_at_k), + timedelta(seconds=int(seconds_todo)), + elapsed_time)) + sys.stdout.flush() + sys.stderr.flush() # Save a checkpoint of training every so often. if (args.checkpoint_frequency > 0 and @@ -413,7 +422,7 @@ def main(): # Stop the main-loop at the end of the step, if requested. if u.interrupted: - print("Interrupted on request!") + log.info("Interrupted on request!") break # Store one final checkpoint. This might be redundant, but it is crucial