0a34ed6798
Now that we have richer exceptions, this is the easiest way to refactor out rewrite rule error handling from the various main functions where it currently lives.
298 lines
9.3 KiB
Python
298 lines
9.3 KiB
Python
"""Test CLI utilities"""
|
|
# Copyright © 2020, 2021 Brett Smith
|
|
# License: AGPLv3-or-later WITH Beancount-Plugin-Additional-Permission-1.0
|
|
#
|
|
# Full copyright and licensing details can be found at toplevel file
|
|
# LICENSE.txt in the repository.
|
|
|
|
import argparse
|
|
import datetime
|
|
import enum
|
|
import errno
|
|
import io
|
|
import inspect
|
|
import logging
|
|
import os
|
|
import re
|
|
import sys
|
|
import traceback
|
|
|
|
import pytest
|
|
|
|
from pathlib import Path
|
|
|
|
from . import testutil
|
|
|
|
from conservancy_beancount import cliutil, errors
|
|
|
|
# File names joined onto tmp_path by the *_output_path tests in this module.
# Both begin with a hyphen.
FILE_NAMES = ['-foobar', '-foo.bin']
|
|
# Path arguments for which the *_output_stream tests in this module expect the
# caller-supplied stream object to be returned unchanged.
STREAM_PATHS = [None, Path('-')]
|
|
|
|
class ArgChoices(enum.Enum):
    """Small enumeration wrapped by the ``arg_enum`` fixture.

    The enum-argument tests look members up both by name and by value.
    """

    AA = 'aa'
    AB = 'ab'
    BB = 'bb'
|
|
|
|
|
|
class MockTraceback:
    """Traceback stand-in built from ``inspect.stack()`` frame records.

    Exposes ``tb_frame``, ``tb_lineno`` and ``tb_next``, each taken from one
    entry of a captured stack.
    """

    def __init__(self, stack=None, index=0):
        """Wrap the frame record at position ``index`` of ``stack``.

        When ``stack`` is None the current call stack is captured here, so
        its first record is this ``__init__`` frame. Raises IndexError if
        ``index`` is out of range for the stack.
        """
        # Capture inline (not in a helper) so the recorded frames are unchanged.
        self._stack = inspect.stack(context=False) if stack is None else stack
        self._index = index
        record = self._stack[self._index]
        self.tb_frame = record.frame
        self.tb_lineno = record.lineno

    @property
    def tb_next(self):
        """Return a mock for the following stack entry, or None past the end."""
        following = self._index + 1
        try:
            successor = type(self)(self._stack, following)
        except IndexError:
            successor = None
        return successor
|
|
|
|
|
|
@pytest.fixture(scope='module')
def arg_enum():
    """Provide one module-scoped ``cliutil.EnumArgument`` wrapping ``ArgChoices``."""
    enum_argument = cliutil.EnumArgument(ArgChoices)
    return enum_argument
|
|
|
|
@pytest.fixture(scope='module')
def argparser():
    """Provide a module-scoped parser carrying the loglevel and version arguments."""
    test_parser = argparse.ArgumentParser(prog='test_cliutil')
    # Register the arguments in the same order as before: loglevel, then version.
    for register in (cliutil.add_loglevel_argument, cliutil.add_version_argument):
        register(test_parser)
    return test_parser
|
|
|
|
@pytest.mark.parametrize('path_name', FILE_NAMES)
def test_bytes_output_path(path_name, tmp_path):
    """A real path makes ``cliutil.bytes_output`` open that file in binary write mode.

    The fallback stream must not be returned. The returned file is used as a
    context manager so it is closed even when an assertion fails, instead of
    leaking an open handle into the rest of the test session.
    """
    path = tmp_path / path_name
    stream = io.BytesIO()
    with cliutil.bytes_output(path, stream) as actual:
        assert actual is not stream
        assert str(actual.name) == str(path)
        assert 'w' in actual.mode
        assert 'b' in actual.mode
|
|
|
|
@pytest.mark.parametrize('path', STREAM_PATHS)
def test_bytes_output_stream(path):
    """For each stream-style path, ``cliutil.bytes_output`` returns the given stream itself."""
    fallback = io.BytesIO()
    result = cliutil.bytes_output(path, fallback)
    assert result is fallback
|
|
|
|
@pytest.mark.parametrize('year,month,day', [
    (2000, 1, 1),
    (2016, 2, 29),
    (2020, 12, 31),
])
def test_date_arg_valid(year, month, day):
    """``cliutil.date_arg`` turns an ISO-format date string back into the same date."""
    wanted = datetime.date(year, month, day)
    parsed = cliutil.date_arg(wanted.isoformat())
    assert parsed == wanted
|
|
|
|
@pytest.mark.parametrize('arg', [
    '2000',
    '20-02-12',
    '2019-02-29',
    'two thousand',
])
def test_date_arg_invalid(arg):
    """``cliutil.date_arg`` raises ValueError for each parametrized string."""
    with pytest.raises(ValueError):
        cliutil.date_arg(arg)
|
|
|
|
@pytest.mark.parametrize('year', [
    1990,
    2000,
    2009,
])
def test_year_or_date_arg_year(year):
    """A bare year string is returned by ``cliutil.year_or_date_arg`` as that year."""
    parsed = cliutil.year_or_date_arg(str(year))
    assert parsed == year
|
|
|
|
@pytest.mark.parametrize('year,month,day', [
    (2000, 1, 1),
    (2016, 2, 29),
    (2020, 12, 31),
])
def test_year_or_date_arg_date(year, month, day):
    """An ISO-format date string is returned by ``cliutil.year_or_date_arg`` as that date."""
    wanted = datetime.date(year, month, day)
    parsed = cliutil.year_or_date_arg(wanted.isoformat())
    assert parsed == wanted
|
|
|
|
@pytest.mark.parametrize('arg', [
    '-1',
    str(sys.maxsize),
    'MMDVIII',
    '2019-02-29',
])
def test_year_or_date_arg_invalid(arg):
    """``cliutil.year_or_date_arg`` raises ValueError for each parametrized string."""
    with pytest.raises(ValueError):
        cliutil.year_or_date_arg(arg)
|
|
|
|
@pytest.mark.parametrize('func_name', [
    'bytes_output',
    'text_output',
])
def test_default_output(func_name):
    """Called with no arguments, each output helper returns a stream on stdout's descriptor."""
    open_output = getattr(cliutil, func_name)
    actual = open_output()
    assert actual.fileno() == sys.stdout.fileno()
|
|
|
|
@pytest.mark.parametrize('path_name', FILE_NAMES)
def test_text_output_path(path_name, tmp_path):
    """A real path makes ``cliutil.text_output`` open that file in write mode.

    The fallback stream must not be returned. The returned file is used as a
    context manager so it is closed even when an assertion fails, instead of
    leaking an open handle into the rest of the test session.
    """
    path = tmp_path / path_name
    stream = io.StringIO()
    with cliutil.text_output(path, stream) as actual:
        assert actual is not stream
        assert str(actual.name) == str(path)
        assert 'w' in actual.mode
|
|
|
|
@pytest.mark.parametrize('path', STREAM_PATHS)
def test_text_output_stream(path):
    """For each stream-style path, ``cliutil.text_output`` returns the given stream itself."""
    fallback = io.StringIO()
    result = cliutil.text_output(path, fallback)
    assert result is fallback
|
|
|
|
@pytest.mark.parametrize('errnum', [
    errno.EACCES,
    errno.EPERM,
    errno.ENOENT,
])
def test_excepthook_oserror(errnum, caplog):
    """An OSError makes ExceptHook exit with ``os.EX_IOERR`` and log a CRITICAL I/O error."""
    error = OSError(errnum, os.strerror(errnum), 'TestFilename')
    hook = cliutil.ExceptHook()
    with pytest.raises(SystemExit) as exit_info:
        hook(type(error), error, None)
    assert exit_info.value.args[0] == os.EX_IOERR
    # Guard against a vacuous pass: the loop below needs at least one record.
    assert caplog.records
    expected_message = f"I/O error: {error.filename}: {error.strerror}"
    for record in caplog.records:
        assert record.levelname == 'CRITICAL'
        assert record.message == expected_message
|
|
|
|
@pytest.mark.parametrize('errcls', [
    errors.RewriteRuleActionError,
    errors.RewriteRuleConditionError,
    errors.RewriteRuleLoadError,
    errors.RewriteRuleValidationError,
])
def test_excepthook_rewrite_rule_error(errcls, caplog):
    """Rewrite rule errors exit with ``ExitCode.RewriteRulesError`` and are logged as CRITICAL.

    The first line of each log message must name the error class, file and
    rule number; the last line must show the rule source.
    """
    name = errcls.__name__
    error = errcls("bad rewrite rule", f"{name}.yml", 170, [name])
    hook = cliutil.ExceptHook()
    with pytest.raises(SystemExit) as exit_info:
        hook(type(error), error, None)
    assert exit_info.value.args[0] == cliutil.ExitCode.RewriteRulesError
    # Guard against a vacuous pass: the loop below needs at least one record.
    assert caplog.records
    headline = f"{name}: bad rewrite rule in {name}.yml rule #170"
    source_pattern = re.compile(rf' source:\W+{name}\b')
    for record in caplog.records:
        assert record.levelname == 'CRITICAL'
        message_lines = record.message.splitlines()
        assert message_lines[0].startswith(headline)
        assert source_pattern.match(message_lines[-1])
|
|
|
|
@pytest.mark.parametrize('exc_type', [
    AttributeError,
    RuntimeError,
    ValueError,
])
def test_excepthook_bug(exc_type, caplog):
    """These exception types exit with ``os.EX_SOFTWARE`` and log a CRITICAL "internal" message."""
    error = exc_type("test message")
    hook = cliutil.ExceptHook()
    with pytest.raises(SystemExit) as exit_info:
        hook(exc_type, error, None)
    assert exit_info.value.args[0] == os.EX_SOFTWARE
    # Guard against a vacuous pass: the loop below needs at least one record.
    assert caplog.records
    expected_message = f"internal {exc_type.__name__}: {error.args[0]}"
    for record in caplog.records:
        assert record.levelname == 'CRITICAL'
        assert record.message == expected_message
|
|
|
|
def test_excepthook_traceback(caplog):
    """With DEBUG logging on, the last record ExceptHook logs is the formatted traceback."""
    error = KeyError('test')
    hook_args = (type(error), error, MockTraceback())
    caplog.set_level(logging.DEBUG)
    hook = cliutil.ExceptHook()
    with pytest.raises(SystemExit):
        hook(*hook_args)
    assert caplog.records
    expected_text = ''.join(traceback.format_exception(*hook_args))
    assert caplog.records[-1].message == expected_text
|
|
|
|
@pytest.mark.parametrize('arg,expected', [
    ('debug', logging.DEBUG),
    ('info', logging.INFO),
    ('warning', logging.WARNING),
    ('warn', logging.WARNING),
    ('error', logging.ERROR),
    ('err', logging.ERROR),
    ('critical', logging.CRITICAL),
    ('crit', logging.CRITICAL),
])
def test_loglevel_argument(argparser, arg, expected):
    """``--loglevel`` maps each level name to its logging constant in lower, title and upper case."""
    for recase in (str.lower, str.title, str.upper):
        parsed = argparser.parse_args(['--loglevel', recase(arg)])
        assert parsed.loglevel is expected
|
|
|
|
def test_setup_logger():
    """A logger from ``cliutil.setup_logger`` writes to the given stream in the given format."""
    sink = io.StringIO()
    log_format = '%(name)s %(levelname)s: %(message)s'
    logger = cliutil.setup_logger('test_cliutil', sink, log_format)
    logger.critical("test crit")
    written = sink.getvalue()
    assert written == "test_cliutil CRITICAL: test crit\n"
|
|
|
|
@pytest.mark.parametrize('arg', [
    '--license',
    '--version',
    '--copyright',
])
def test_version_argument(argparser, capsys, arg):
    """Each version-style flag exits with status 0 after printing the version banner.

    The parser is given the parametrized ``arg``, so every listed flag is
    exercised rather than ``--version`` three times.

    NOTE(review): assumes ``cliutil.add_version_argument`` registers all three
    option strings; confirm against that function.
    """
    with pytest.raises(SystemExit) as exc_check:
        argparser.parse_args([arg])
    assert exc_check.value.args[0] == 0
    stdout, _ = capsys.readouterr()
    lines = iter(stdout.splitlines())
    # The "<EOF>" default makes empty output fail the match rather than raise
    # StopIteration.
    assert re.match(r'^test_cliutil version \d+\.\d+\.\d+', next(lines, "<EOF>"))
|
|
|
|
@pytest.mark.parametrize('date,diff,expected', [
    # An ordinary day keeps its month and day when the year shifts.
    (datetime.date(2010, 2, 28), 0, datetime.date(2010, 2, 28)),
    (datetime.date(2010, 2, 28), 1, datetime.date(2011, 2, 28)),
    (datetime.date(2010, 2, 28), 2, datetime.date(2012, 2, 28)),
    (datetime.date(2010, 2, 28), -1, datetime.date(2009, 2, 28)),
    (datetime.date(2010, 2, 28), -2, datetime.date(2008, 2, 28)),
    # Leap day: into a non-leap year it becomes March 1 going forward and
    # February 28 going backward; into a leap year it stays February 29.
    (datetime.date(2012, 2, 29), 2, datetime.date(2014, 3, 1)),
    (datetime.date(2012, 2, 29), 4, datetime.date(2016, 2, 29)),
    (datetime.date(2012, 2, 29), -2, datetime.date(2010, 2, 28)),
    (datetime.date(2012, 2, 29), -4, datetime.date(2008, 2, 29)),
    (datetime.date(2010, 3, 1), 1, datetime.date(2011, 3, 1)),
    (datetime.date(2010, 3, 1), 2, datetime.date(2012, 3, 1)),
    (datetime.date(2010, 3, 1), -1, datetime.date(2009, 3, 1)),
    (datetime.date(2010, 3, 1), -2, datetime.date(2008, 3, 1)),
])
def test_diff_year(date, diff, expected):
    """``cliutil.diff_year`` shifts ``date`` by ``diff`` years to give ``expected``."""
    shifted = cliutil.diff_year(date, diff)
    assert shifted == expected
|
|
|
|
@pytest.mark.parametrize('cmd,expected', [
    (['true'], True),
    (['true', '--version'], True),
    (['false'], False),
    (['false', '--version'], False),
    ([str(testutil.TESTS_DIR)], False),
])
def test_can_run(cmd, expected):
    """``cliutil.can_run`` returns the expected verdict for each parametrized command."""
    verdict = cliutil.can_run(cmd)
    assert verdict == expected
|
|
|
|
@pytest.mark.parametrize('choice', ArgChoices)
def test_enum_arg_enum_type(arg_enum, choice):
    """``enum_type`` resolves both a member's name and its value to that member."""
    for spelling in (choice.name, choice.value):
        assert arg_enum.enum_type(spelling) is choice
|
|
|
|
# The string is iterated character by character, giving three separate cases.
@pytest.mark.parametrize('arg', 'az\0')
def test_enum_arg_no_enum_match(arg_enum, arg):
    """``enum_type`` raises ValueError for each parametrized non-matching argument."""
    with pytest.raises(ValueError):
        arg_enum.enum_type(arg)
|
|
|
|
@pytest.mark.parametrize('choice', ArgChoices)
def test_enum_arg_value_type(arg_enum, choice):
    """``value_type`` maps both a member's name and its value to the member's value."""
    for spelling in (choice.name, choice.value):
        assert arg_enum.value_type(spelling) == choice.value
|
|
|
|
# The string is iterated character by character, giving three separate cases.
@pytest.mark.parametrize('arg', 'az\0')
def test_enum_arg_no_value_match(arg_enum, arg):
    """``value_type`` raises ValueError for each parametrized non-matching argument."""
    with pytest.raises(ValueError):
        arg_enum.value_type(arg)
|
|
|
|
def test_enum_arg_choices_str_defaults(arg_enum):
    """With no arguments, ``choices_str`` joins the repr of each value with ', '."""
    expected = ', '.join(repr(member.value) for member in ArgChoices)
    assert arg_enum.choices_str() == expected
|
|
|
|
def test_enum_arg_choices_str_args(arg_enum):
    """``choices_str`` uses the separator and format string passed to it."""
    sep = '/'
    expected = sep.join(member.value for member in ArgChoices)
    assert arg_enum.choices_str(sep, '{}') == expected
|