diff --git a/conservancy_beancount/cliutil.py b/conservancy_beancount/cliutil.py new file mode 100644 index 0000000..58340e7 --- /dev/null +++ b/conservancy_beancount/cliutil.py @@ -0,0 +1,148 @@ +"""cliutil - Utilities for CLI tools""" +PKGNAME = 'conservancy_beancount' +LICENSE = """ +Copyright © 2020 Brett Smith + +This program is free software: you can redistribute it and/or modify +it under the terms of the GNU Affero General Public License as published by +the Free Software Foundation, either version 3 of the License, or +(at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU Affero General Public License for more details. + +You should have received a copy of the GNU Affero General Public License +along with this program. If not, see .""" + +import argparse +import enum +import logging +import operator +import os +import pkg_resources +import signal +import sys +import traceback +import types + +from typing import ( + Any, + Iterable, + NoReturn, + Optional, + Sequence, + TextIO, + Type, + Union, +) + +VERSION = pkg_resources.require(PKGNAME)[0].version + +class ExceptHook: + def __init__(self, + logger: Optional[logging.Logger]=None, + default_exitcode: int=3, + ) -> None: + if logger is None: + logger = logging.getLogger() + self.logger = logger + self.default_exitcode = default_exitcode + + def __call__(self, + exc_type: Type[BaseException], + exc_value: BaseException, + exc_tb: types.TracebackType, + ) -> NoReturn: + exitcode = self.default_exitcode + if isinstance(exc_value, KeyboardInterrupt): + signal.signal(signal.SIGINT, signal.SIG_DFL) + os.kill(0, signal.SIGINT) + signal.pause() + elif isinstance(exc_value, OSError): + exitcode += 1 + msg = "I/O error: {e.filename}: {e.strerror}".format(e=exc_value) + else: + parts = [type(exc_value).__name__, *exc_value.args] + msg = "internal " + ": ".join(parts) + self.logger.critical(msg) + self.logger.debug( + ''.join(traceback.format_exception(exc_type, exc_value, exc_tb)), + ) + raise SystemExit(exitcode) + + +class InfoAction(argparse.Action): + def __call__(self, + parser: argparse.ArgumentParser, + namespace: argparse.Namespace, + values: Union[Sequence[Any], str, None]=None, + option_string: Optional[str]=None, + ) -> NoReturn: + if isinstance(self.const, str): + info = self.const + exitcode = 0 + else: + info, exitcode = self.const + print(info) + raise SystemExit(exitcode) + + +class LogLevel(enum.IntEnum): + DEBUG = logging.DEBUG + INFO = logging.INFO + WARNING = logging.WARNING + ERROR = logging.ERROR + CRITICAL = logging.CRITICAL + WARN = WARNING + ERR = ERROR + CRIT = CRITICAL + + @classmethod + def from_arg(cls, arg: str) -> int: + try: + return cls[arg.upper()].value + except KeyError: + raise ValueError(f"unknown loglevel {arg!r}") from None + + @classmethod + def choices(cls) -> Iterable[str]: + for level in sorted(cls, key=operator.attrgetter('value')): + yield level.name.lower() + +def add_loglevel_argument(parser: argparse.ArgumentParser, + default: LogLevel=LogLevel.INFO) -> argparse.Action: + return parser.add_argument( + '--loglevel', + metavar='LEVEL', + default=default.value, + type=LogLevel.from_arg, + help="Show logs at this level and above." + f" Specify one of {', '.join(LogLevel.choices())}." + f" Default {default.name.lower()}.", + ) + +def add_version_argument(parser: argparse.ArgumentParser) -> argparse.Action: + progname = parser.prog or sys.argv[0] + return parser.add_argument( + '--version', '--copyright', '--license', + action=InfoAction, + nargs=0, + const=f"{progname} version {VERSION}\n{LICENSE}", + help="Show program version and license information", + ) + +def setup_logger(logger: Union[str, logging.Logger]='', + loglevel: int=logging.INFO, + stream: TextIO=sys.stderr, + fmt: str='%(name)s: %(levelname)s: %(message)s', +) -> logging.Logger: + if isinstance(logger, str): + logger = logging.getLogger(logger) + formatter = logging.Formatter(fmt) + handler = logging.StreamHandler(stream) + handler.setFormatter(formatter) + logger.addHandler(handler) + logger.setLevel(loglevel) + return logger diff --git a/tests/test_cliutil.py b/tests/test_cliutil.py new file mode 100644 index 0000000..e8a3867 --- /dev/null +++ b/tests/test_cliutil.py @@ -0,0 +1,129 @@ +"""Test CLI utilities""" +# Copyright © 2020 Brett Smith +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +import argparse +import errno +import io +import inspect +import logging +import os +import re +import traceback + +import pytest + +from conservancy_beancount import cliutil + +class MockTraceback: + def __init__(self, stack=None, index=0): + if stack is None: + stack = inspect.stack(context=False) + self._stack = stack + self._index = index + frame_record = self._stack[self._index] + self.tb_frame = frame_record.frame + self.tb_lineno = frame_record.lineno + + @property + def tb_next(self): + try: + return type(self)(self._stack, self._index + 1) + except IndexError: + return None + + +@pytest.fixture(scope='module') +def argparser(): + parser = argparse.ArgumentParser(prog='test_cliutil') + cliutil.add_loglevel_argument(parser) + cliutil.add_version_argument(parser) + return parser + +@pytest.mark.parametrize('errnum', [ + errno.EACCES, + errno.EPERM, + errno.ENOENT, +]) +def test_excepthook_oserror(errnum, caplog): + error = OSError(errnum, os.strerror(errnum), 'TestFilename') + with pytest.raises(SystemExit) as exc_check: + cliutil.ExceptHook()(type(error), error, None) + assert exc_check.value.args[0] == 4 + assert caplog.records + for log in caplog.records: + assert log.levelname == 'CRITICAL' + assert log.message == f"I/O error: {error.filename}: {error.strerror}" + +@pytest.mark.parametrize('exc_type', [ + AttributeError, + RuntimeError, + ValueError, +]) +def test_excepthook_bug(exc_type, caplog): + error = exc_type("test message") + with pytest.raises(SystemExit) as exc_check: + cliutil.ExceptHook()(exc_type, error, None) + assert exc_check.value.args[0] == 3 + assert caplog.records + for log in caplog.records: + assert log.levelname == 'CRITICAL' + assert log.message == f"internal {exc_type.__name__}: {error.args[0]}" + +def test_excepthook_traceback(caplog): + error = KeyError('test') + args = (type(error), error, MockTraceback()) + caplog.set_level(logging.DEBUG) + with pytest.raises(SystemExit) as exc_check: + cliutil.ExceptHook()(*args) + assert caplog.records + assert caplog.records[-1].message == ''.join(traceback.format_exception(*args)) + +@pytest.mark.parametrize('arg,expected', [ + ('debug', logging.DEBUG), + ('info', logging.INFO), + ('warning', logging.WARNING), + ('warn', logging.WARNING), + ('error', logging.ERROR), + ('err', logging.ERROR), + ('critical', logging.CRITICAL), + ('crit', logging.CRITICAL), +]) +def test_loglevel_argument(argparser, arg, expected): + for method in ['lower', 'title', 'upper']: + args = argparser.parse_args(['--loglevel', getattr(arg, method)()]) + assert args.loglevel is expected + +def test_setup_logger(): + stream = io.StringIO() + logger = cliutil.setup_logger( + 'test_cliutil', logging.INFO, stream, '%(name)s %(levelname)s: %(message)s', + ) + logger.debug("test debug") + logger.info("test info") + assert stream.getvalue() == "test_cliutil INFO: test info\n" + +@pytest.mark.parametrize('arg', [ + '--license', + '--version', + '--copyright', +]) +def test_version_argument(argparser, capsys, arg): + with pytest.raises(SystemExit) as exc_check: + args = argparser.parse_args(['--version']) + assert exc_check.value.args[0] == 0 + stdout, _ = capsys.readouterr() + lines = iter(stdout.splitlines()) + assert re.match(r'^test_cliutil version \d+\.\d+\.\d+', next(lines, ""))