diff --git a/conservancy_beancount/cliutil.py b/conservancy_beancount/cliutil.py
new file mode 100644
index 0000000..58340e7
--- /dev/null
+++ b/conservancy_beancount/cliutil.py
@@ -0,0 +1,148 @@
+"""cliutil - Utilities for CLI tools"""
+PKGNAME = 'conservancy_beancount'
+LICENSE = """
+Copyright © 2020 Brett Smith
+
+This program is free software: you can redistribute it and/or modify
+it under the terms of the GNU Affero General Public License as published by
+the Free Software Foundation, either version 3 of the License, or
+(at your option) any later version.
+
+This program is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+GNU Affero General Public License for more details.
+
+You should have received a copy of the GNU Affero General Public License
+along with this program. If not, see ."""
+
+import argparse
+import enum
+import logging
+import operator
+import os
+import pkg_resources
+import signal
+import sys
+import traceback
+import types
+
+from typing import (
+ Any,
+ Iterable,
+ NoReturn,
+ Optional,
+ Sequence,
+ TextIO,
+ Type,
+ Union,
+)
+
+VERSION = pkg_resources.require(PKGNAME)[0].version
+
+class ExceptHook:
+ def __init__(self,
+ logger: Optional[logging.Logger]=None,
+ default_exitcode: int=3,
+ ) -> None:
+ if logger is None:
+ logger = logging.getLogger()
+ self.logger = logger
+ self.default_exitcode = default_exitcode
+
+ def __call__(self,
+ exc_type: Type[BaseException],
+ exc_value: BaseException,
+ exc_tb: types.TracebackType,
+ ) -> NoReturn:
+ exitcode = self.default_exitcode
+ if isinstance(exc_value, KeyboardInterrupt):
+ signal.signal(signal.SIGINT, signal.SIG_DFL)
+ os.kill(0, signal.SIGINT)
+ signal.pause()
+ elif isinstance(exc_value, OSError):
+ exitcode += 1
+ msg = "I/O error: {e.filename}: {e.strerror}".format(e=exc_value)
+ else:
+ parts = [type(exc_value).__name__, *exc_value.args]
+ msg = "internal " + ": ".join(parts)
+ self.logger.critical(msg)
+ self.logger.debug(
+ ''.join(traceback.format_exception(exc_type, exc_value, exc_tb)),
+ )
+ raise SystemExit(exitcode)
+
+
+class InfoAction(argparse.Action):
+ def __call__(self,
+ parser: argparse.ArgumentParser,
+ namespace: argparse.Namespace,
+ values: Union[Sequence[Any], str, None]=None,
+ option_string: Optional[str]=None,
+ ) -> NoReturn:
+ if isinstance(self.const, str):
+ info = self.const
+ exitcode = 0
+ else:
+ info, exitcode = self.const
+ print(info)
+ raise SystemExit(exitcode)
+
+
+class LogLevel(enum.IntEnum):
+ DEBUG = logging.DEBUG
+ INFO = logging.INFO
+ WARNING = logging.WARNING
+ ERROR = logging.ERROR
+ CRITICAL = logging.CRITICAL
+ WARN = WARNING
+ ERR = ERROR
+ CRIT = CRITICAL
+
+ @classmethod
+ def from_arg(cls, arg: str) -> int:
+ try:
+ return cls[arg.upper()].value
+ except KeyError:
+ raise ValueError(f"unknown loglevel {arg!r}") from None
+
+ @classmethod
+ def choices(cls) -> Iterable[str]:
+ for level in sorted(cls, key=operator.attrgetter('value')):
+ yield level.name.lower()
+
+def add_loglevel_argument(parser: argparse.ArgumentParser,
+ default: LogLevel=LogLevel.INFO) -> argparse.Action:
+ return parser.add_argument(
+ '--loglevel',
+ metavar='LEVEL',
+ default=default.value,
+ type=LogLevel.from_arg,
+ help="Show logs at this level and above."
+ f" Specify one of {', '.join(LogLevel.choices())}."
+ f" Default {default.name.lower()}.",
+ )
+
+def add_version_argument(parser: argparse.ArgumentParser) -> argparse.Action:
+ progname = parser.prog or sys.argv[0]
+ return parser.add_argument(
+ '--version', '--copyright', '--license',
+ action=InfoAction,
+ nargs=0,
+ const=f"{progname} version {VERSION}\n{LICENSE}",
+ help="Show program version and license information",
+ )
+
+def setup_logger(logger: Union[str, logging.Logger]='',
+ loglevel: int=logging.INFO,
+ stream: TextIO=sys.stderr,
+ fmt: str='%(name)s: %(levelname)s: %(message)s',
+) -> logging.Logger:
+ if isinstance(logger, str):
+ logger = logging.getLogger(logger)
+ formatter = logging.Formatter(fmt)
+ handler = logging.StreamHandler(stream)
+ handler.setFormatter(formatter)
+ logger.addHandler(handler)
+ logger.setLevel(loglevel)
+ return logger
diff --git a/tests/test_cliutil.py b/tests/test_cliutil.py
new file mode 100644
index 0000000..e8a3867
--- /dev/null
+++ b/tests/test_cliutil.py
@@ -0,0 +1,129 @@
+"""Test CLI utilities"""
+# Copyright © 2020 Brett Smith
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+import argparse
+import errno
+import io
+import inspect
+import logging
+import os
+import re
+import traceback
+
+import pytest
+
+from conservancy_beancount import cliutil
+
+class MockTraceback:
+ def __init__(self, stack=None, index=0):
+ if stack is None:
+ stack = inspect.stack(context=False)
+ self._stack = stack
+ self._index = index
+ frame_record = self._stack[self._index]
+ self.tb_frame = frame_record.frame
+ self.tb_lineno = frame_record.lineno
+
+ @property
+ def tb_next(self):
+ try:
+ return type(self)(self._stack, self._index + 1)
+ except IndexError:
+ return None
+
+
+@pytest.fixture(scope='module')
+def argparser():
+ parser = argparse.ArgumentParser(prog='test_cliutil')
+ cliutil.add_loglevel_argument(parser)
+ cliutil.add_version_argument(parser)
+ return parser
+
+@pytest.mark.parametrize('errnum', [
+ errno.EACCES,
+ errno.EPERM,
+ errno.ENOENT,
+])
+def test_excepthook_oserror(errnum, caplog):
+ error = OSError(errnum, os.strerror(errnum), 'TestFilename')
+ with pytest.raises(SystemExit) as exc_check:
+ cliutil.ExceptHook()(type(error), error, None)
+ assert exc_check.value.args[0] == 4
+ assert caplog.records
+ for log in caplog.records:
+ assert log.levelname == 'CRITICAL'
+ assert log.message == f"I/O error: {error.filename}: {error.strerror}"
+
+@pytest.mark.parametrize('exc_type', [
+ AttributeError,
+ RuntimeError,
+ ValueError,
+])
+def test_excepthook_bug(exc_type, caplog):
+ error = exc_type("test message")
+ with pytest.raises(SystemExit) as exc_check:
+ cliutil.ExceptHook()(exc_type, error, None)
+ assert exc_check.value.args[0] == 3
+ assert caplog.records
+ for log in caplog.records:
+ assert log.levelname == 'CRITICAL'
+ assert log.message == f"internal {exc_type.__name__}: {error.args[0]}"
+
+def test_excepthook_traceback(caplog):
+ error = KeyError('test')
+ args = (type(error), error, MockTraceback())
+ caplog.set_level(logging.DEBUG)
+ with pytest.raises(SystemExit) as exc_check:
+ cliutil.ExceptHook()(*args)
+ assert caplog.records
+ assert caplog.records[-1].message == ''.join(traceback.format_exception(*args))
+
+@pytest.mark.parametrize('arg,expected', [
+ ('debug', logging.DEBUG),
+ ('info', logging.INFO),
+ ('warning', logging.WARNING),
+ ('warn', logging.WARNING),
+ ('error', logging.ERROR),
+ ('err', logging.ERROR),
+ ('critical', logging.CRITICAL),
+ ('crit', logging.CRITICAL),
+])
+def test_loglevel_argument(argparser, arg, expected):
+ for method in ['lower', 'title', 'upper']:
+ args = argparser.parse_args(['--loglevel', getattr(arg, method)()])
+ assert args.loglevel is expected
+
+def test_setup_logger():
+ stream = io.StringIO()
+ logger = cliutil.setup_logger(
+ 'test_cliutil', logging.INFO, stream, '%(name)s %(levelname)s: %(message)s',
+ )
+ logger.debug("test debug")
+ logger.info("test info")
+ assert stream.getvalue() == "test_cliutil INFO: test info\n"
+
+@pytest.mark.parametrize('arg', [
+ '--license',
+ '--version',
+ '--copyright',
+])
+def test_version_argument(argparser, capsys, arg):
+ with pytest.raises(SystemExit) as exc_check:
+ args = argparser.parse_args(['--version'])
+ assert exc_check.value.args[0] == 0
+ stdout, _ = capsys.readouterr()
+ lines = iter(stdout.splitlines())
+ assert re.match(r'^test_cliutil version \d+\.\d+\.\d+', next(lines, ""))