conservancy_beancount/tests/test_cliutil.py

229 lines
7 KiB
Python

"""Test CLI utilities"""
# Copyright © 2020 Brett Smith
# License: AGPLv3-or-later WITH Beancount-Plugin-Additional-Permission-1.0
#
# Full copyright and licensing details can be found at toplevel file
# LICENSE.txt in the repository.
import argparse
import datetime
import errno
import io
import inspect
import logging
import os
import re
import sys
import traceback
import pytest
from pathlib import Path
from conservancy_beancount import cliutil
# File names used for output-path tests. Both start with a dash
# (presumably so they could be mistaken for option flags if mishandled).
FILE_NAMES = [
    '-foobar',
    '-foo.bin',
]
# Path arguments for which the *_output helpers are expected to return the
# stream they were given instead of opening a file.
STREAM_PATHS = [
    None,
    Path('-'),
]
class MockTraceback:
    """Fake traceback object built from a list of frame records.

    Provides the attributes the standard ``traceback`` module reads from a
    traceback (``tb_frame``, ``tb_lineno``, ``tb_lasti``, ``tb_next``), so a
    test can format a "traceback" without actually raising an exception.

    Args:
        stack: Sequence of records exposing ``frame`` and ``lineno``
            attributes, such as the ``inspect.FrameInfo`` list returned by
            ``inspect.stack()``. Defaults to the call stack at construction.
        index: Position in ``stack`` that this link of the chain represents.

    Raises:
        IndexError: if ``index`` is out of range for ``stack``.
    """

    # No bytecode offset is known for these frames. Python 3.11+'s traceback
    # module reads tb_lasti to look up fine-grained code positions; a negative
    # value makes it fall back to tb_lineno instead of raising AttributeError.
    tb_lasti = -1

    def __init__(self, stack=None, index=0):
        if stack is None:
            stack = inspect.stack(context=False)
        self._stack = stack
        self._index = index
        frame_record = self._stack[self._index]
        self.tb_frame = frame_record.frame
        self.tb_lineno = frame_record.lineno

    @property
    def tb_next(self):
        """The next link in the chain, or None once the stack is exhausted."""
        try:
            return type(self)(self._stack, self._index + 1)
        except IndexError:
            return None
@pytest.fixture(scope='module')
def argparser():
    """Module-wide parser named ``test_cliutil`` with the cliutil log-level
    and version arguments installed (in that order)."""
    result = argparse.ArgumentParser(prog='test_cliutil')
    for install in (cliutil.add_loglevel_argument, cliutil.add_version_argument):
        install(result)
    return result
@pytest.mark.parametrize('path_name', FILE_NAMES)
def test_bytes_output_path(path_name, tmp_path):
    """Given a real path, bytes_output opens that file for binary writing
    rather than returning the fallback stream."""
    path = tmp_path / path_name
    stream = io.BytesIO()
    actual = cliutil.bytes_output(path, stream)
    # Close the opened file when the checks finish so the test does not
    # leak a file handle (and trigger a ResourceWarning).
    with actual:
        assert actual is not stream
        assert str(actual.name) == str(path)
        assert 'w' in actual.mode
        assert 'b' in actual.mode
@pytest.mark.parametrize('path', STREAM_PATHS)
def test_bytes_output_stream(path):
    """For a "stream" path, bytes_output returns the given stream object."""
    buffer = io.BytesIO()
    assert cliutil.bytes_output(path, buffer) is buffer
@pytest.mark.parametrize(('year', 'month', 'day'), [
    (2000, 1, 1),
    (2016, 2, 29),
    (2020, 12, 31),
])
def test_date_arg_valid(year, month, day):
    """date_arg turns an ISO 8601 date string back into the same date."""
    want = datetime.date(year, month, day)
    got = cliutil.date_arg(want.isoformat())
    assert got == want
@pytest.mark.parametrize('arg', ['2000', '20-02-12', '2019-02-29', 'two thousand'])
def test_date_arg_invalid(arg):
    """Non-ISO or impossible dates make date_arg raise ValueError."""
    expect_failure = pytest.raises(ValueError)
    with expect_failure:
        cliutil.date_arg(arg)
@pytest.mark.parametrize('year', [1990, 2000, 2009])
def test_year_or_date_arg_year(year):
    """A bare year string comes back from year_or_date_arg as that year."""
    parsed = cliutil.year_or_date_arg(str(year))
    assert parsed == year
@pytest.mark.parametrize(('year', 'month', 'day'), [
    (2000, 1, 1),
    (2016, 2, 29),
    (2020, 12, 31),
])
def test_year_or_date_arg_date(year, month, day):
    """A full ISO 8601 date string comes back from year_or_date_arg as a date."""
    want = datetime.date(year, month, day)
    got = cliutil.year_or_date_arg(want.isoformat())
    assert got == want
@pytest.mark.parametrize('arg', ['-1', str(sys.maxsize), 'MMDVIII', '2019-02-29'])
def test_year_or_date_arg_invalid(arg):
    """Negative, huge, non-numeric, or impossible inputs raise ValueError."""
    expect_failure = pytest.raises(ValueError)
    with expect_failure:
        cliutil.year_or_date_arg(arg)
@pytest.mark.parametrize('func_name', ['bytes_output', 'text_output'])
def test_default_output(func_name):
    """Called with no arguments, each output helper targets stdout's fd."""
    make_output = getattr(cliutil, func_name)
    assert make_output().fileno() == sys.stdout.fileno()
@pytest.mark.parametrize('path_name', FILE_NAMES)
def test_text_output_path(path_name, tmp_path):
    """Given a real path, text_output opens that file for writing rather
    than returning the fallback stream."""
    path = tmp_path / path_name
    stream = io.StringIO()
    actual = cliutil.text_output(path, stream)
    # Close the opened file when the checks finish so the test does not
    # leak a file handle (and trigger a ResourceWarning).
    with actual:
        assert actual is not stream
        assert str(actual.name) == str(path)
        assert 'w' in actual.mode
@pytest.mark.parametrize('path', STREAM_PATHS)
def test_text_output_stream(path):
    """For a "stream" path, text_output returns the given stream object."""
    buffer = io.StringIO()
    assert cliutil.text_output(path, buffer) is buffer
@pytest.mark.parametrize('errnum', [errno.EACCES, errno.EPERM, errno.ENOENT])
def test_excepthook_oserror(errnum, caplog):
    """An OSError makes ExceptHook exit with EX_IOERR, and every captured log
    record is a CRITICAL "I/O error: <filename>: <strerror>" message."""
    error = OSError(errnum, os.strerror(errnum), 'TestFilename')
    with pytest.raises(SystemExit) as exc_check:
        cliutil.ExceptHook()(type(error), error, None)
    assert exc_check.value.args[0] == os.EX_IOERR
    captured = caplog.records
    assert captured
    wanted = f"I/O error: {error.filename}: {error.strerror}"
    for record in captured:
        assert record.levelname == 'CRITICAL'
        assert record.message == wanted
@pytest.mark.parametrize('exc_type', [AttributeError, RuntimeError, ValueError])
def test_excepthook_bug(exc_type, caplog):
    """A non-I/O exception makes ExceptHook exit with EX_SOFTWARE, and every
    captured log record is a CRITICAL "internal <Type>: <message>" line."""
    error = exc_type("test message")
    with pytest.raises(SystemExit) as exc_check:
        cliutil.ExceptHook()(exc_type, error, None)
    assert exc_check.value.args[0] == os.EX_SOFTWARE
    captured = caplog.records
    assert captured
    wanted = f"internal {exc_type.__name__}: {error.args[0]}"
    for record in captured:
        assert record.levelname == 'CRITICAL'
        assert record.message == wanted
def test_excepthook_traceback(caplog):
    """At DEBUG level, ExceptHook's last log record is the fully formatted
    traceback of the exception it handled."""
    error = KeyError('test')
    args = (type(error), error, MockTraceback())
    caplog.set_level(logging.DEBUG)
    # Only the exit itself matters here, not its status, so the
    # ExceptionInfo is not bound to a name.
    with pytest.raises(SystemExit):
        cliutil.ExceptHook()(*args)
    assert caplog.records
    assert caplog.records[-1].message == ''.join(traceback.format_exception(*args))
@pytest.mark.parametrize('arg,expected', [
    ('debug', logging.DEBUG),
    ('info', logging.INFO),
    ('warning', logging.WARNING),
    ('warn', logging.WARNING),
    ('error', logging.ERROR),
    ('err', logging.ERROR),
    ('critical', logging.CRITICAL),
    ('crit', logging.CRITICAL),
])
def test_loglevel_argument(argparser, arg, expected):
    """--loglevel maps each level name or short alias to its logging level,
    whether written in lower, title, or upper case."""
    for method in ['lower', 'title', 'upper']:
        args = argparser.parse_args(['--loglevel', getattr(arg, method)()])
        # Compare by value: logging levels are ints, and `is` on ints only
        # works by accident of CPython's small-integer cache.
        assert args.loglevel == expected
def test_setup_logger():
    """setup_logger returns a logger that writes records to the given stream
    using the given format string."""
    output = io.StringIO()
    log_format = '%(name)s %(levelname)s: %(message)s'
    logger = cliutil.setup_logger('test_cliutil', output, log_format)
    logger.critical("test crit")
    assert output.getvalue() == "test_cliutil CRITICAL: test crit\n"
@pytest.mark.parametrize('arg', [
    '--license',
    '--version',
    '--copyright',
])
def test_version_argument(argparser, capsys, arg):
    """Each version/info flag exits with status 0 after printing a first
    line of the form "test_cliutil version X.Y.Z"."""
    with pytest.raises(SystemExit) as exc_check:
        # Parse the parametrized flag itself so every alias is exercised.
        argparser.parse_args([arg])
    assert exc_check.value.args[0] == 0
    stdout, _ = capsys.readouterr()
    first_line = next(iter(stdout.splitlines()), "<EOF>")
    assert re.match(r'^test_cliutil version \d+\.\d+\.\d+', first_line)
@pytest.mark.parametrize('date,diff,expected', [
    (datetime.date(*start), years, datetime.date(*end))
    for start, years, end in [
        ((2010, 2, 28), 0, (2010, 2, 28)),
        ((2010, 2, 28), 1, (2011, 2, 28)),
        ((2010, 2, 28), 2, (2012, 2, 28)),
        ((2010, 2, 28), -1, (2009, 2, 28)),
        ((2010, 2, 28), -2, (2008, 2, 28)),
        ((2012, 2, 29), 2, (2014, 3, 1)),
        ((2012, 2, 29), 4, (2016, 2, 29)),
        ((2012, 2, 29), -2, (2010, 2, 28)),
        ((2012, 2, 29), -4, (2008, 2, 29)),
        ((2010, 3, 1), 1, (2011, 3, 1)),
        ((2010, 3, 1), 2, (2012, 3, 1)),
        ((2010, 3, 1), -1, (2009, 3, 1)),
        ((2010, 3, 1), -2, (2008, 3, 1)),
    ]
])
def test_diff_year(date, diff, expected):
    """diff_year shifts a date by whole years.

    The cases fix how Feb 29 behaves when the target year is not a leap
    year: moving forward gives Mar 1, moving backward gives Feb 28.
    """
    shifted = cliutil.diff_year(date, diff)
    assert shifted == expected