From 06d3a29307d4641c4fbd58e82861de31fc76af1d Mon Sep 17 00:00:00 2001 From: Zach White Date: Sun, 28 Jan 2024 17:32:09 -0800 Subject: [PATCH] add type hints to milc (now we'll never have a bug again!) --- .gitignore | 1 + ci_tests | 6 + milc/__init__.py | 11 +- milc/_in_argv.py | 5 +- milc/_sparkline.py | 20 +++- milc/ansi.py | 6 +- milc/attrdict.py | 31 +++--- milc/configuration.py | 70 ++++++------ milc/milc.py | 225 ++++++++++++++++++++++++-------------- milc/questions.py | 138 +++++++++++++++-------- milc/subcommand/config.py | 20 ++-- py.typed | 0 setup.py | 2 + 13 files changed, 345 insertions(+), 190 deletions(-) create mode 100644 py.typed diff --git a/.gitignore b/.gitignore index 595e6e5..d9af7ea 100644 --- a/.gitignore +++ b/.gitignore @@ -9,3 +9,4 @@ milc.egg-info venv site .venv +.vscode diff --git a/ci_tests b/ci_tests index d04427b..7cc1c6d 100755 --- a/ci_tests +++ b/ci_tests @@ -43,6 +43,12 @@ def main(cli): build_ok = False cli.log.error('Improperly formatted code. Please run this: yapf -i -r .') + cli.log.info('Running mypy...') + cmd = ['mypy', '--strict', 'milc'] + result = run(cmd, stdin=DEVNULL) + if result.returncode != 0: + build_ok = False + if build_ok: cli.log.info('{fg_green}All tests passed!') return True diff --git a/milc/__init__.py b/milc/__init__.py index 346959d..506b3ec 100644 --- a/milc/__init__.py +++ b/milc/__init__.py @@ -18,6 +18,7 @@ import os import sys import warnings +from typing import Optional from .emoji import EMOJI_LOGLEVELS from .milc import MILC @@ -32,7 +33,13 @@ cli = MILC() -def set_metadata(*, name=None, author=None, version=None, logger=None): +def set_metadata( + *, + name: Optional[str] = None, + author: Optional[str] = None, + version: Optional[str] = None, + logger: Optional[logging.Logger] = None, +) -> MILC: """Set metadata about your program. This allows you to set the application's name, version, and/or author @@ -48,6 +55,8 @@ def set_metadata(*, name=None, author=None, version=None, logger=None): cli = MILC(name, version, author, logger) + return cli + # Extra stuff people can import from ._sparkline import sparkline # noqa diff --git a/milc/_in_argv.py b/milc/_in_argv.py index 8ca0ff9..1227fd3 100644 --- a/milc/_in_argv.py +++ b/milc/_in_argv.py @@ -1,7 +1,8 @@ import sys +from typing import Optional -def _in_argv(argument): +def _in_argv(argument: str) -> bool: """Returns true if the argument is found is sys.argv. Since long options can be passed as either '--option value' or '--option=value' we need to check for both forms. @@ -13,7 +14,7 @@ def _in_argv(argument): return False -def _index_argv(argument): +def _index_argv(argument: str) -> Optional[int]: """Returns the location of the argument in sys.argv, or None. Since long options can be passed as either '--option value' or '--option=value' we need to check for both forms. diff --git a/milc/_sparkline.py b/milc/_sparkline.py index 2f264e6..00bf61e 100644 --- a/milc/_sparkline.py +++ b/milc/_sparkline.py @@ -3,19 +3,35 @@ """ from decimal import Decimal from math import inf +from typing import Any, List, Optional, TypeGuard from milc import cli spark_chars = '▁▂▃▄▅▆▇█' -def is_number(i): +def is_number(i: Any) -> TypeGuard[bool]: """Returns true if i is a number. Used to filter non-numbers from a list. """ return isinstance(i, (int, float, Decimal)) -def sparkline(number_list, *, min_value=None, max_value=None, highlight_low=-inf, highlight_high=inf, highlight_low_color='', highlight_high_color='', negative_color='{fg_red}', positive_color='', highlight_low_reset='{fg_reset}', highlight_high_reset='{fg_reset}', negative_reset='{fg_reset}', positive_reset='{fg_reset}'): +def sparkline( + number_list: List[Optional[int]], + *, + min_value: Optional[int] = None, + max_value: Optional[int] = None, + highlight_low: float = -inf, + highlight_high: float = inf, + highlight_low_color: str = '', + highlight_high_color: str = '', + negative_color: str = '{fg_red}', + positive_color: str = '', + highlight_low_reset: str = '{fg_reset}', + highlight_high_reset: str = '{fg_reset}', + negative_reset: str = '{fg_reset}', + positive_reset: str = '{fg_reset}', +) -> str: """Display a sparkline from a sequence of numbers. If you wish to exclude extreme values, or you want to limit the set of characters used, you can adjust `min_value` and `max_value` to your own values. Values between your actual min/max will exclude datapoints, while values outside your actual min/max will compress your data into fewer sparks. diff --git a/milc/ansi.py b/milc/ansi.py index 3b81376..e7e5a74 100644 --- a/milc/ansi.py +++ b/milc/ansi.py @@ -4,6 +4,7 @@ import re import logging import colorama +from typing import Any from .emoji import EMOJI_LOGLEVELS @@ -41,7 +42,7 @@ ansi_colors[prefix + '_' + color.lower()] = getattr(obj, color) -def format_ansi(text): +def format_ansi(text: str) -> str: """Return a copy of text with certain strings replaced with ansi. """ # Avoid .format() so we don't have to worry about the log content @@ -59,9 +60,10 @@ def format_ansi(text): class MILCFormatter(logging.Formatter): """Formats log records per the MILC configuration. """ - def format(self, record): + def format(self, record: Any) -> Any: if ansi_config['unicode'] and record.levelname in EMOJI_LOGLEVELS: record.levelname = format_ansi(EMOJI_LOGLEVELS[record.levelname]) msg = super().format(record) + return format_ansi(msg) diff --git a/milc/attrdict.py b/milc/attrdict.py index 66b1f90..5e151ce 100644 --- a/milc/attrdict.py +++ b/milc/attrdict.py @@ -1,43 +1,46 @@ +from typing import Any, Dict, List + + class AttrDict(object): """A dictionary that can also be accessed by attribute. """ - def __contains__(self, key): + def __contains__(self, key: Any) -> bool: return self._data.__contains__(key) - def __iter__(self): + def __iter__(self) -> Any: return self._data.__iter__() - def __len__(self): + def __len__(self) -> Any: return self._data.__len__() - def __repr__(self): + def __repr__(self) -> Any: return self._data.__repr__() - def keys(self): + def keys(self) -> Any: return self._data.keys() - def items(self): + def items(self) -> Any: return self._data.items() - def values(self): + def values(self) -> Any: return self._data.values() - def __init__(self, *args, **kwargs): - self._data = {} + def __init__(self, *args: List[Any], **kwargs: Dict[Any, Any]) -> None: + self._data: Dict[Any, Any] = {} - def __getattr__(self, key): + def __getattr__(self, key: Any) -> Any: return self.__getitem__(key) - def __getitem__(self, key): + def __getitem__(self, key: Any) -> Any: """Returns an item. """ return self._data[key] - def __setitem__(self, key, value): + def __setitem__(self, key: Any, value: Any) -> None: self._data[key] = value self.__setattr__(key, value) - def __delitem__(self, key): + def __delitem__(self, key: Any) -> None: if key in self._data: del self._data[key] @@ -48,7 +51,7 @@ class SparseAttrDict(AttrDict): This class never raises IndexError, instead it will return None if a key does not yet exist. """ - def __getitem__(self, key): + def __getitem__(self, key: Any) -> Any: """Returns an item, creating it if it doesn't already exist """ if key not in self._data: diff --git a/milc/configuration.py b/milc/configuration.py index 22d5f2a..89d8502 100644 --- a/milc/configuration.py +++ b/milc/configuration.py @@ -1,3 +1,5 @@ +from typing import Any, Hashable, List + from .attrdict import AttrDict @@ -7,7 +9,7 @@ class Configuration(AttrDict): This class never raises IndexError, instead it will return None if a section or option does not yet exist. """ - def __getitem__(self, key): + def __getitem__(self, key: Hashable) -> Any: """Returns a config section, creating it if it doesn't exist yet. """ if key not in self._data: @@ -17,11 +19,11 @@ def __getitem__(self, key): class ConfigurationSection(Configuration): - def __init__(self, parent, *args, **kwargs): + def __init__(self, parent: AttrDict, *args: Any, **kwargs: Any) -> None: super(ConfigurationSection, self).__init__(*args, **kwargs) self._parent = parent - def __getitem__(self, key): + def __getitem__(self, key: Hashable) -> Any: """Returns a config value, pulling from the `user` section as a fallback. This is called when the attribute is accessed either via the get method or through [ ] index. """ @@ -33,7 +35,7 @@ def __getitem__(self, key): return None - def __getattr__(self, key): + def __getattr__(self, key: str) -> Any: """Returns the config value from the `user` section. This is called when the attribute is accessed via dot notation but does not exist. """ @@ -42,7 +44,7 @@ def __getattr__(self, key): return None - def __setattr__(self, key, value): + def __setattr__(self, key: str, value: Any) -> None: """Sets dictionary value when an attribute is set. """ super().__setattr__(key, value) @@ -54,7 +56,9 @@ def __setattr__(self, key, value): class SubparserWrapper(object): """Wrap subparsers so we can track what options the user passed. """ - def __init__(self, cli, submodule, subparser): + + # We type `cli` as Any instead of MILC to avoid a circular import + def __init__(self, cli: Any, submodule: Any, subparser: Any) -> None: self.cli = cli self.submodule = submodule self.subparser = subparser @@ -63,66 +67,70 @@ def __init__(self, cli, submodule, subparser): if not hasattr(self, attr): setattr(self, attr, getattr(subparser, attr)) - def completer(self, completer): + def completer(self, completer: Any) -> None: """Add an arpcomplete completer to this subcommand. """ self.subparser.completer = completer - def add_argument(self, *args, **kwargs): + def add_argument(self, *args: Any, **kwargs: Any) -> None: """Add an argument for this subcommand. This also stores the default for the argument in `self.cli.default_arguments`. """ if kwargs.get('action') == 'store_boolean': # Store boolean will call us again with the enable/disable flag arguments - return handle_store_boolean(self, *args, **kwargs) + handle_store_boolean(self, *args, **kwargs) - completer = None + else: + completer = None - if kwargs.get('completer'): - completer = kwargs['completer'] - del kwargs['completer'] + if kwargs.get('completer'): + completer = kwargs['completer'] + del kwargs['completer'] - self.cli.acquire_lock() - argument_name = get_argument_name(self.cli._arg_parser, *args, **kwargs) + self.cli.acquire_lock() + argument_name = get_argument_name(self.cli._arg_parser, *args, **kwargs) - if completer: - self.subparser.add_argument(*args, **kwargs).completer = completer - else: - self.subparser.add_argument(*args, **kwargs) + if completer: + self.subparser.add_argument(*args, **kwargs).completer = completer + else: + self.subparser.add_argument(*args, **kwargs) - if kwargs.get('action') == 'store_false': - self.cli._config_store_false.append(argument_name) + if kwargs.get('action') == 'store_false': + self.cli._config_store_false.append(argument_name) - if kwargs.get('action') == 'store_true': - self.cli._config_store_true.append(argument_name) + if kwargs.get('action') == 'store_true': + self.cli._config_store_true.append(argument_name) - if self.submodule not in self.cli.default_arguments: - self.cli.default_arguments[self.submodule] = {} + if self.submodule not in self.cli.default_arguments: + self.cli.default_arguments[self.submodule] = {} - self.cli.default_arguments[self.submodule][argument_name] = kwargs.get('default') - self.cli.release_lock() + self.cli.default_arguments[self.submodule][argument_name] = kwargs.get('default') + self.cli.release_lock() -def get_argument_strings(arg_parser, *args, **kwargs): +def get_argument_strings(arg_parser: Any, *args: Any, **kwargs: Any) -> List[str]: """Takes argparse arguments and returns a list of argument strings or positional names. """ try: - return arg_parser._get_optional_kwargs(*args, **kwargs)['option_strings'] + return arg_parser._get_optional_kwargs(*args, **kwargs)['option_strings'] # type: ignore[no-any-return] + except ValueError: return [arg_parser._get_positional_kwargs(*args, **kwargs)['dest']] -def get_argument_name(arg_parser, *args, **kwargs): +def get_argument_name(arg_parser: Any, *args: Any, **kwargs: Any) -> Any: """Takes argparse arguments and returns the dest name. """ try: return arg_parser._get_optional_kwargs(*args, **kwargs)['dest'] + except ValueError: return arg_parser._get_positional_kwargs(*args, **kwargs)['dest'] -def handle_store_boolean(self, *args, **kwargs): +# FIXME: We should not be using self in this way +def handle_store_boolean(self: Any, *args: Any, **kwargs: Any) -> Any: """Does the add_argument for action='store_boolean'. """ disabled_args = None diff --git a/milc/milc.py b/milc/milc.py index ae125cf..3757a35 100644 --- a/milc/milc.py +++ b/milc/milc.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python3 # coding=utf-8 import argparse import logging @@ -12,28 +11,38 @@ from pathlib import Path from platform import platform from tempfile import NamedTemporaryFile +from types import TracebackType +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple try: import threading except ImportError: - threading = None + threading = None # type: ignore[assignment] import argcomplete import colorama -from appdirs import user_config_dir -from halo import Halo -from spinners.spinners import Spinners +from appdirs import user_config_dir # type: ignore[import-untyped] +from halo import Halo # type: ignore[import-untyped] +from spinners.spinners import Spinners # type: ignore[import-untyped] from .ansi import MILCFormatter, ansi_colors, ansi_config, ansi_escape, format_ansi from .attrdict import AttrDict from .configuration import Configuration, SubparserWrapper, get_argument_name, get_argument_strings, handle_store_boolean from ._in_argv import _in_argv, _index_argv +# FIXME: Replace Callable[..., Any] with better definitions + class MILC(object): """MILC - An Opinionated Batteries Included Framework """ - def __init__(self, name=None, version=None, author=None, logger=None): + def __init__( + self, + name: Optional[str] = None, + version: Optional[str] = None, + author: Optional[str] = None, + logger: Optional[logging.Logger] = None, + ) -> None: """Initialize the MILC object. """ # Set some defaults @@ -54,22 +63,21 @@ def __init__(self, name=None, version=None, author=None, logger=None): self.prog_name = name self.version = version self.author = author - self._config_store_true = [] - self._config_store_false = [] - self._description = None - self._entrypoint = None - self._spinners = {} + self._config_store_true: Sequence[str] = [] + self._config_store_false: Sequence[str] = [] + self._entrypoint: Callable[[Any], Any] = lambda _: None + self._spinners: Dict[str, Dict[str, int | Sequence[str]]] = {} self._subcommand = None self._inside_context_manager = False self.ansi = ansi_colors - self.arg_only = {} + self.arg_only: Dict[str, List[str]] = {} self.config_file = self.find_config_file() - self.default_arguments = {} + self.default_arguments: Dict[str, Dict[str, Optional[str]]] = {} self.platform = platform() self.interactive = sys.stdin.isatty() self.release_lock() - self._deprecated_arguments = {} - self._deprecated_commands = {} + self._deprecated_arguments: Dict[str, str] = {} + self._deprecated_commands: Dict[str, str] = {} # Initialize all the things self.initialize_config() @@ -77,24 +85,25 @@ def __init__(self, name=None, version=None, author=None, logger=None): self.initialize_logging(logger) @property - def config_dir(self): + def config_dir(self) -> Path: return self.config_file.parent @property - def description(self): - return self._description + def description(self) -> Optional[str]: + return self._arg_parser.description @description.setter - def description(self, value): - self._description = self._arg_parser.description = value + def description(self, value: str) -> None: + self._arg_parser.description = value - def argv_name(self): + def argv_name(self) -> str: """Returns the name of our program by examining argv. """ app_name = sys.argv[0][:-3] if sys.argv[0].endswith('.py') else sys.argv[0] + return os.path.split(app_name)[-1] - def echo(self, text, *args, **kwargs): + def echo(self, text: str, *args: Any, **kwargs: Any) -> None: """Print colorized text to stdout. ANSI color strings (such as {fg_blue}) will be converted into ANSI @@ -104,17 +113,26 @@ def echo(self, text, *args, **kwargs): If *args or **kwargs are passed they will be used to %-format the strings. """ if args and kwargs: - raise RuntimeError('You can only specify *args or **kwargs, not both!') + raise ValueError('You can only specify *args or **kwargs, not both!') - args = args or kwargs - text = format_ansi(text % args) + if args: + text = format_ansi(text % args) + else: + text = format_ansi(text % kwargs) if not self.config.general.color: text = ansi_escape.sub('', text) print(text) - def run(self, command, capture_output=True, combined_output=False, text=True, **kwargs): + def run( + self, + command: Sequence[str], + capture_output: bool = True, + combined_output: bool = False, + text: bool = True, + **kwargs: Any, + ) -> subprocess.CompletedProcess[bytes | str]: """Run a command using `subprocess.run`, but using some different defaults. Unlike subprocess.run you must supply a sequence of arguments. You can use `shlex.split()` to build this from a string. @@ -146,8 +164,7 @@ def run(self, command, capture_output=True, combined_output=False, text=True, ** # stdin is broken so things like milc.questions no longer work. # We pass `stdin=subprocess.DEVNULL` by default to prevent that. if 'windows' in self.platform.lower(): - safecmd = map(shlex.quote, command) - safecmd = ' '.join(safecmd) + safecmd = ' '.join(map(shlex.quote, command)) command = [os.environ['SHELL'], '-c', safecmd] if 'stdin' not in kwargs: @@ -172,7 +189,7 @@ def run(self, command, capture_output=True, combined_output=False, text=True, ** return subprocess.run(command, **kwargs) - def initialize_argparse(self): + def initialize_argparse(self) -> None: """Prepare to process arguments from sys.argv. """ kwargs = { @@ -181,41 +198,44 @@ def initialize_argparse(self): } self.acquire_lock() - self.subcommands = {} - self._subparsers = None - self.argwarn = argcomplete.warn + + self.subcommands: Dict[str, Any] = {} + self._subparsers: Optional[Any] = None # FIXME: Find a better type signature + self.argwarn = argcomplete.warn # type: ignore[attr-defined] self.args = AttrDict() self.args_passed = AttrDict() - self._arg_parser = argparse.ArgumentParser(**kwargs) + self._arg_parser = argparse.ArgumentParser(**kwargs) # type: ignore[arg-type] self.set_defaults = self._arg_parser.set_defaults + self.release_lock() - def print_help(self, *args, **kwargs): + def print_help(self, *args: Any, **kwargs: Any) -> None: """Print a help message for the main program or subcommand, depending on context. """ if self._subcommand: - return self.subcommands[self._subcommand.__name__].print_help(*args, **kwargs) - - return self._arg_parser.print_help(*args, **kwargs) + self.subcommands[self._subcommand.__name__].print_help(*args, **kwargs) + else: + self._arg_parser.print_help(*args, **kwargs) - def print_usage(self, *args, **kwargs): + def print_usage(self, *args: Any, **kwargs: Any) -> None: """Print brief description of how the main program or subcommand is invoked, depending on context. """ if self._subcommand: - return self.subcommands[self._subcommand.__name__].print_usage(*args, **kwargs) - - return self._arg_parser.print_usage(*args, **kwargs) + self.subcommands[self._subcommand.__name__].print_usage(*args, **kwargs) + else: + self._arg_parser.print_usage(*args, **kwargs) - def log_deprecated_warning(self, item_type, name, reason): + def log_deprecated_warning(self, item_type: str, name: str, reason: str) -> None: """Logs a warning with a custom message if a argument or command is deprecated. """ self.log.warning("Warning: %s '%s' is deprecated:\n\t%s", item_type, name, reason) - def add_argument(self, *args, **kwargs): + def add_argument(self, *args: Any, **kwargs: Any) -> None: """Wrapper to add arguments and track whether they were passed on the command line. """ if 'action' in kwargs and kwargs['action'] == 'store_boolean': - return handle_store_boolean(self, *args, **kwargs) + handle_store_boolean(self, *args, **kwargs) + return arg_name = get_argument_name(self._arg_parser, *args, **kwargs) arg_strings = get_argument_strings(self._arg_parser, *args, **kwargs) @@ -229,7 +249,7 @@ def add_argument(self, *args, **kwargs): self.acquire_lock() if completer: - self._arg_parser.add_argument(*args, **kwargs).completer = completer + self._arg_parser.add_argument(*args, **kwargs).completer = completer # type: ignore[attr-defined] else: self._arg_parser.add_argument(*args, **kwargs) @@ -251,24 +271,27 @@ def add_argument(self, *args, **kwargs): self.release_lock() - def initialize_logging(self, logger): + def initialize_logging(self, logger: Optional[logging.Logger]) -> None: """Prepare the defaults for the logging infrastructure. """ if not logger: logger = logging.getLogger(self.__class__.__name__) self.acquire_lock() + self.log_file = None self.log_file_mode = 'a' - self.log_file_handler = None + self.log_file_handler: Optional[logging.FileHandler] = None self.log_print = True self.log_print_to = sys.stderr self.log_print_level = logging.INFO self.log_file_level = logging.INFO self.log_level = logging.INFO self.log = logger + self.log.setLevel(logging.DEBUG) logging.root.setLevel(logging.DEBUG) + self.release_lock() self.add_argument('-V', '--version', version=self.version, action='version', help='Display the version and exit') @@ -282,9 +305,10 @@ def initialize_logging(self, logger): self.add_argument('--unicode', action='store_boolean', default=ansi_config['unicode'], help='unicode loglevels') self.add_argument('--interactive', action='store_true', help='Force interactive mode even when stdout is not a tty.') self.add_argument('--config-file', help='The location for the configuration file') + self.arg_only['config_file'] = ['general'] - def add_subparsers(self, title='Sub-commands', **kwargs): + def add_subparsers(self, title: str = 'Sub-commands', **kwargs: Any) -> None: if self._inside_context_manager: raise RuntimeError('You must run this before the with statement!') @@ -292,24 +316,27 @@ def add_subparsers(self, title='Sub-commands', **kwargs): self._subparsers = self._arg_parser.add_subparsers(title=title, dest='subparsers', **kwargs) self.release_lock() - def acquire_lock(self, blocking=True): + def acquire_lock(self, blocking: bool = True) -> bool: """Acquire the MILC lock for exclusive access to properties. """ if self._lock: - self._lock.acquire(blocking) + return self._lock.acquire(blocking) - def release_lock(self): + return True + + def release_lock(self) -> None: """Release the MILC lock. """ if self._lock: self._lock.release() @lru_cache(maxsize=None) - def find_config_file(self): + def find_config_file(self) -> Path: """Locate the config file. """ - if _in_argv('--config-file'): - config_file_index = _index_argv('--config-file') + config_file_index = _index_argv('--config-file') + + if config_file_index is not None: config_file_param = sys.argv[config_file_index] if '=' in config_file_param: @@ -327,13 +354,13 @@ def find_config_file(self): return Path(filedir, filename).resolve() - def argument(self, *args, **kwargs): + def argument(self, *args: Any, **kwargs: Any) -> Callable[..., Any]: """Decorator to call self.add_argument or self..add_argument. """ if self._inside_context_manager: raise RuntimeError('You must run this before the with statement!') - def argument_function(handler): + def argument_function(handler: Callable[..., Any]) -> Callable[..., Any]: config_name = handler.__name__ subcommand_name = config_name.replace("_", "-") arg_name = get_argument_name(self._arg_parser, *args, **kwargs) @@ -384,7 +411,7 @@ def argument_function(handler): return argument_function - def parse_args(self): + def parse_args(self) -> None: """Parse the CLI args. """ if self.args: @@ -394,6 +421,7 @@ def parse_args(self): argcomplete.autocomplete(self._arg_parser) self.acquire_lock() + for key, value in vars(self._arg_parser.parse_args()).items(): self.args[key] = value @@ -402,7 +430,7 @@ def parse_args(self): self.release_lock() - def read_config_file(self): + def read_config_file(self) -> Tuple[Configuration, Configuration]: """Read in the configuration file and return Configuration objects for it and the config_source. """ config = Configuration() @@ -435,14 +463,14 @@ def read_config_file(self): return config, config_source - def initialize_config(self): + def initialize_config(self) -> None: """Read in the configuration file and store it in self.config. """ self.acquire_lock() self.config, self.config_source = self.read_config_file() self.release_lock() - def merge_args_into_config(self): + def merge_args_into_config(self) -> None: """Merge CLI arguments into self.config to create the runtime configuration. """ self.acquire_lock() @@ -472,7 +500,7 @@ def merge_args_into_config(self): self.release_lock() - def _save_config_file(self, config): + def _save_config_file(self, config: AttrDict) -> None: """Write config to disk. """ # Generate a sanitized version of our running configuration @@ -497,7 +525,7 @@ def _save_config_file(self, config): self.log.warning('Config file saving failed, not replacing %s with %s.', str(self.config_file), tmpfile.name) self.release_lock() - def write_config_option(self, section, option): + def write_config_option(self, section: str, option: Any) -> None: """Save a single config option to the config file. """ if not self.config_file: @@ -516,7 +544,7 @@ def write_config_option(self, section, option): # Housekeeping self.log.info('Wrote configuration to %s', shlex.quote(str(self.config_file))) - def save_config(self): + def save_config(self) -> None: """Save the current configuration to the config file. """ self.log.debug("Saving config file to '%s'", str(self.config_file)) @@ -529,7 +557,7 @@ def save_config(self): self._save_config_file(self.config) self.log.info('Wrote configuration to %s', shlex.quote(str(self.config_file))) - def check_deprecated(self): + def check_deprecated(self) -> None: entry_name = self._entrypoint.__name__ if entry_name in self._deprecated_commands: @@ -549,7 +577,7 @@ def check_deprecated(self): msg = self._deprecated_arguments[arg] self.log_deprecated_warning('Argument', arg, msg) - def __call__(self): + def __call__(self) -> Any: """Execute the entrypoint function. """ if not self._inside_context_manager: @@ -561,12 +589,13 @@ def __call__(self): if self._subcommand: return self._subcommand(self) - elif self._entrypoint: + + elif self._entrypoint is not None: return self._entrypoint(self) raise RuntimeError('No entrypoint provided!') - def entrypoint(self, description, deprecated=None): + def entrypoint(self, description: str, deprecated: Optional[str] = None) -> Callable[..., Any]: """Decorator that marks the entrypoint used when a subcommand is not supplied. Args: description @@ -582,13 +611,12 @@ def entrypoint(self, description, deprecated=None): self.description = description self.release_lock() - def entrypoint_func(handler): + def entrypoint_func(handler: Callable[..., Any]) -> Callable[..., Any]: self.acquire_lock() if deprecated: - name = handler.__name__ - self._deprecated_commands[name] = deprecated - self.description += f' [Deprecated]: {deprecated}' + self._deprecated_commands[handler.__name__] = deprecated + self.description = f'{self.description} [Deprecated]: {deprecated}' self._entrypoint = handler self.release_lock() @@ -597,7 +625,14 @@ def entrypoint_func(handler): return entrypoint_func - def add_subcommand(self, handler, description, hidden=False, deprecated=None, **kwargs): + def add_subcommand( + self, + handler: Callable[..., Any], + description: str, + hidden: bool = False, + deprecated: Optional[str] = None, + **kwargs: Any, + ) -> Callable[..., Any]: """Register a subcommand. Args: @@ -628,9 +663,11 @@ def add_subcommand(self, handler, description, hidden=False, deprecated=None, ** description += f' [Deprecated]: {deprecated}' self.acquire_lock() - if not hidden: + + if not hidden and self._subparsers is not None: self._subparsers.metavar = "{%s,%s}" % (self._subparsers.metavar[1:-1], name) if self._subparsers.metavar else "{%s%s}" % (self._subparsers.metavar[1:-1], name) kwargs['help'] = description + self.subcommands[name] = SubparserWrapper(self, name, self._subparsers.add_parser(name, **kwargs)) self.subcommands[name].set_defaults(entrypoint=handler) @@ -638,7 +675,7 @@ def add_subcommand(self, handler, description, hidden=False, deprecated=None, ** return handler - def subcommand(self, description, hidden=False, **kwargs): + def subcommand(self, description: str, hidden: bool = False, **kwargs: Any) -> Callable[..., Any]: """Decorator to register a subcommand. Args: @@ -649,12 +686,12 @@ def subcommand(self, description, hidden=False, **kwargs): hidden When True don't display this command in --help """ - def subcommand_function(handler): + def subcommand_function(handler: Callable[..., Any]) -> Callable[..., Any]: return self.add_subcommand(handler, description, hidden=hidden, **kwargs) return subcommand_function - def setup_logging(self): + def setup_logging(self) -> None: """Called by __enter__() to setup the logging configuration. """ if len(logging.root.handlers) != 0: @@ -688,7 +725,7 @@ def setup_logging(self): self.release_lock() - def __enter__(self): + def __enter__(self) -> Any: if self._inside_context_manager: self.log.debug('Warning: context manager was entered again. This usually means that self.__call__() was called before the with statement. You probably do not want to do that.') return @@ -708,7 +745,12 @@ def __enter__(self): return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__( + self, + exc_type: Optional[type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: self.acquire_lock() self._inside_context_manager = False self.release_lock() @@ -718,12 +760,12 @@ def __exit__(self, exc_type, exc_val, exc_tb): logging.exception(exc_val) exit(255) - def is_spinner(self, name): + def is_spinner(self, name: str) -> bool: """Returns true if name is a valid spinner. """ return name in Spinners.__members__ or name in self._spinners - def add_spinner(self, name, spinner): + def add_spinner(self, name: str, spinner: Dict[str, int | Sequence[str]]) -> None: """Adds a new spinner to the list of spinners. A spinner is a dictionary with two keys: @@ -745,7 +787,19 @@ def add_spinner(self, name, spinner): self._spinners[name] = spinner - def spinner(self, text, *args, spinner=None, animation='ellipsed', placement='left', color='blue', interval=-1, stream=sys.stdout, enabled=True, **kwargs): + def spinner( + self, + text: str, + *args: Any, + spinner: Optional[str] = None, + animation: str = 'ellipsed', + placement: str = 'left', + color: str = 'blue', + interval: int = -1, + stream: Any = sys.stdout, + enabled: bool = True, + **kwargs: Any, + ) -> Halo: """Create a spinner object for showing activity to the user. This uses halo behind the scenes, most of the arguments map to Halo objects 1:1. @@ -818,12 +872,15 @@ def long_running_function(): enabled Enable or disable the spinner. Defaults to `True`. """ - if isinstance(spinner, str) and spinner in self._spinners: - spinner = self._spinners[spinner] + spinner_name = spinner or 'line' # FIXME: Grab one of the ascii spinners at random instead of line + + if spinner in self._spinners: + spinner_name = '' + spinner_obj = self._spinners[spinner] return Halo( text=format_ansi(text % (args or kwargs)), - spinner=spinner if spinner else 'line', # FIXME: Grab one of the ascii spinners at random instead of line + spinner=spinner_name or spinner_obj, animation=None if animation == 'ellipsed' else animation, placement=placement, color=color, diff --git a/milc/questions.py b/milc/questions.py index 96db496..396374a 100644 --- a/milc/questions.py +++ b/milc/questions.py @@ -1,19 +1,20 @@ """Sometimes you need to ask the user a question. MILC provides basic functions for collecting and validating user input. You can find these in the `milc.questions` module. """ from getpass import getpass +from typing import Any, Callable, Optional, Sequence -import milc +from milc import cli from .ansi import format_ansi -def yesno(prompt, *args, default=None, **kwargs): +def yesno(prompt: str, *args: Any, default: Optional[bool] = None, **kwargs: Any) -> bool: """Displays `prompt` to the user and gets a yes or no response. Returns `True` for a yes and `False` for a no. | Argument | Description | |----------|-------------| - | prompt | The prompt to present to the user. Can include ANSI and format strings like milc's `cli.echo()`. | + | prompt | The prompt to present to the user. Can include ANSI and format strings like `cli.echo()`. | | default | Whether to default to a Yes or No when the user presses enter.

None- force the user to enter Y or N
True- Default to yes
False- Default to no | If you add `--yes` and `--no` arguments to your program the user can answer questions by passing command line flags. @@ -23,27 +24,35 @@ def yesno(prompt, *args, default=None, **kwargs): @cli.argument('-n', '--no', action='store_true', arg_only=True, help='Answer no to all questions.') ``` """ - if not args and kwargs: - args = kwargs + if args and kwargs: + raise ValueError("You can't pass both args and kwargs!") - if 'no' in milc.cli.args and milc.cli.args.no: + # Check if we should return an answer without asking + if 'no' in cli.args and cli.args.no: return False - if 'yes' in milc.cli.args and milc.cli.args.yes: + if 'yes' in cli.args and cli.args.yes: return True - if not milc.cli.interactive: + if not cli.interactive: return False + # Format the prompt + if args: + formatted_prompt = prompt % args + else: + formatted_prompt = prompt % kwargs + if default is None: - prompt = prompt + ' [y/n] ' + formatted_prompt = formatted_prompt + ' [y/n] ' elif default: - prompt = prompt + ' [Y/n] ' + formatted_prompt = formatted_prompt + ' [Y/n] ' else: - prompt = prompt + ' [y/N] ' + formatted_prompt = formatted_prompt + ' [y/N] ' + # Get input from the user while True: - answer = input(format_ansi(prompt % args)) + answer = input(format_ansi(formatted_prompt)) if not answer and default is not None: return default @@ -55,9 +64,19 @@ def yesno(prompt, *args, default=None, **kwargs): return False -def password(prompt='Enter password:', *args, confirm=False, confirm_prompt='Confirm password:', confirm_limit=3, validate=None, **kwargs): +def password( + prompt: str = 'Enter password:', + *args: Any, + confirm: bool = False, + confirm_prompt: str = 'Confirm password:', + confirm_limit: int = 3, + validate: Optional[Callable[[str], bool]] = None, + **kwargs: Any, +) -> Optional[str]: """Securely receive a password from the user. Returns the password or None. + When running in non-interactive mode this will always return None. Otherwise it will return the confirmed password the user provides. + | Argument | Description | |----------|-------------| | prompt | The prompt to present to the user. Can include ANSI and format strings like milc's `cli.echo()`. | @@ -66,21 +85,24 @@ def password(prompt='Enter password:', *args, confirm=False, confirm_prompt='Con | confirm_limit | Number of attempts to confirm before giving up. Default: 3 | | validate | This is an optional function that can be used to validate the password, EG to check complexity. It should return True or False and have the following signature:

`def function_name(answer):` | """ - if not milc.cli.interactive: + if not cli.interactive: return None - if not args and kwargs: - args = kwargs + if args: + formatted_prompt = prompt % args + else: + formatted_prompt = prompt % kwargs - if prompt[-1] != ' ': - prompt += ' ' + if formatted_prompt[-1] != ' ': + formatted_prompt += ' ' if confirm_prompt[-1] != ' ': confirm_prompt += ' ' i = 0 + while not confirm_limit or i < confirm_limit: - pw = getpass(format_ansi(prompt % args)) + pw = getpass(format_ansi(formatted_prompt)) if pw: if validate is not None and not validate(pw): @@ -90,15 +112,25 @@ def password(prompt='Enter password:', *args, confirm=False, confirm_prompt='Con if getpass(format_ansi(confirm_prompt % args)) == pw: return pw else: - milc.cli.log.error('Passwords do not match!') + cli.log.error('Passwords do not match!') else: return pw i += 1 + return None + -def question(prompt, *args, default=None, confirm=False, answer_type=str, validate=None, **kwargs): +def question( + prompt: str, + *args: Any, + default: Optional[str] = None, + confirm: bool = False, + answer_type: Callable[[str], str] = str, + validate: Optional[Callable[..., bool]] = None, + **kwargs: Any, +) -> Optional[str]: """Allow the user to type in a free-form string to answer. | Argument | Description | @@ -109,7 +141,7 @@ def question(prompt, *args, default=None, confirm=False, answer_type=str, valida | answer_type | Specify a type function for the answer. Will re-prompt the user if the function raises any errors. Common choices here include int, float, and decimal.Decimal. | | validate | This is an optional function that can be used to validate the answer. It should return True or False and have the following signature:

`def function_name(answer, *args, **kwargs):` | """ - if not milc.cli.interactive: + if not cli.interactive: return default if default is not None: @@ -129,19 +161,28 @@ def question(prompt, *args, default=None, confirm=False, answer_type=str, valida try: return answer_type(answer) except Exception as e: - milc.cli.log.error('Could not convert answer (%s) to type %s: %s', answer, answer_type.__name__, str(e)) + cli.log.error('Could not convert answer (%s) to type %s: %s', answer, answer_type.__name__, str(e)) + return None else: try: return answer_type(answer) except Exception as e: - milc.cli.log.error('Could not convert answer (%s) to type %s: %s', answer, answer_type.__name__, str(e)) + cli.log.error('Could not convert answer (%s) to type %s: %s', answer, answer_type.__name__, str(e)) elif default is not None: return default -def choice(heading, options, *args, default=None, confirm=False, prompt='Please enter your choice: ', **kwargs): +def choice( + heading: str, + options: Sequence[str], + *args: Any, + default: Optional[int] = None, + confirm: bool = False, + prompt: str = 'Please enter your choice: ', + **kwargs: Any, +) -> Optional[str]: """Present the user with a list of options and let them select one. Users can enter either the number or the text of their choice. This will return the value of the item they choose, not the numerical index. @@ -159,22 +200,28 @@ def choice(heading, options, *args, default=None, confirm=False, prompt='Please !!! warning This will return the value of the item they choose, not the numerical index. """ - if not args and kwargs: - args = kwargs + if args: + formatted_heading = heading % args + else: + formatted_heading = heading % kwargs - if not milc.cli.interactive: - return default + if not cli.interactive: + if default is None: + return None + return options[default] - if prompt and default is not None: - prompt = prompt + ' [%s] ' % (default + 1,) - elif prompt[-1] != ' ': + if prompt[-1] != ' ': prompt += ' ' + if default is not None: + prompt = '%s[%s] ' % (prompt, default + 1) + while True: # Prompt for an answer. - milc.cli.echo(heading % args) + cli.echo(formatted_heading) + for i, option in enumerate(options, 1): - milc.cli.echo('\t{fg_cyan}%d.{fg_reset} %s', i, option) + cli.echo('\t{fg_cyan}%d.{fg_reset} %s', i, option) answer = input(format_ansi(prompt)) @@ -184,24 +231,21 @@ def choice(heading, options, *args, default=None, confirm=False, prompt='Please # Massage the answer into a valid integer if answer == '' and default is not None: - answer = default + answer_index = default + elif answer.isnumeric(): + answer_index = int(answer) - 1 else: - try: - answer = int(answer) - 1 - except Exception as e: - milc.cli.log.error('Invalid choice: %s', answer) - milc.cli.log.debug('Could not convert %s to int: %s: %s', answer, e.__class__.__name__, e) - if milc.cli.config.general.verbose: - milc.cli.log.exception(e) - continue + cli.log.error('Invalid choice: %s', answer) + cli.log.debug('Could not convert %s to int', answer) + continue # Validate the answer - if answer >= len(options) or answer < 0: - milc.cli.log.error('Invalid choice: %s', answer + 1) + if answer_index >= len(options) or answer_index < 0: + cli.log.error('Invalid choice: %s', answer_index + 1) continue - if confirm and not yesno('Is the answer "%s" correct?', answer + 1, default=True): + if confirm and not yesno('Is the answer "%s" correct?', answer_index + 1, default=True): continue # Return the answer they chose. - return options[answer] + return options[answer_index] diff --git a/milc/subcommand/config.py b/milc/subcommand/config.py index cf573ea..b2c32a8 100644 --- a/milc/subcommand/config.py +++ b/milc/subcommand/config.py @@ -1,9 +1,12 @@ """Read and write configuration settings """ +from typing import Any, Tuple + import milc +from milc.milc import MILC -def print_config(section, key): +def print_config(section: str, key: str) -> None: """Print a single config setting to stdout. """ if milc.cli.config_source[section][key] == 'config_file': @@ -12,7 +15,7 @@ def print_config(section, key): milc.cli.echo('{fg_cyan}%s.%s=%s', section, key, milc.cli.config[section][key]) -def show_config(): +def show_config() -> None: """Print the current configuration to stdout. """ for section in sorted(milc.cli.config): @@ -21,10 +24,10 @@ def show_config(): print_config(section, key) -def parse_config_token(config_token): +def parse_config_token(config_token: str) -> Tuple[str, str, Any]: """Split a user-supplied configuration-token into its components. """ - section = option = value = None + section = option = value = '' if '=' in config_token and '.' not in config_token: milc.cli.log.error('Invalid configuration token, the key must be of the form
.