"""Test CLI utilities""" # Copyright © 2020 Brett Smith # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. # # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Affero General Public License for more details. # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . import argparse import datetime import errno import io import inspect import logging import os import re import sys import traceback import pytest from pathlib import Path from conservancy_beancount import cliutil FILE_NAMES = ['-foobar', '-foo.bin'] STREAM_PATHS = [None, Path('-')] class MockTraceback: def __init__(self, stack=None, index=0): if stack is None: stack = inspect.stack(context=False) self._stack = stack self._index = index frame_record = self._stack[self._index] self.tb_frame = frame_record.frame self.tb_lineno = frame_record.lineno @property def tb_next(self): try: return type(self)(self._stack, self._index + 1) except IndexError: return None @pytest.fixture(scope='module') def argparser(): parser = argparse.ArgumentParser(prog='test_cliutil') cliutil.add_loglevel_argument(parser) cliutil.add_version_argument(parser) return parser @pytest.mark.parametrize('path_name', FILE_NAMES) def test_bytes_output_path(path_name, tmp_path): path = tmp_path / path_name stream = io.BytesIO() actual = cliutil.bytes_output(path, stream) assert actual is not stream assert str(actual.name) == str(path) assert 'w' in actual.mode assert 'b' in actual.mode @pytest.mark.parametrize('path', STREAM_PATHS) def test_bytes_output_stream(path): stream = io.BytesIO() actual = cliutil.bytes_output(path, stream) 
assert actual is stream @pytest.mark.parametrize('year,month,day', [ (2000, 1, 1), (2016, 2, 29), (2020, 12, 31), ]) def test_date_arg_valid(year, month, day): expected = datetime.date(year, month, day) assert cliutil.date_arg(expected.isoformat()) == expected @pytest.mark.parametrize('arg', [ '2000', '20-02-12', '2019-02-29', 'two thousand', ]) def test_date_arg_invalid(arg): with pytest.raises(ValueError): cliutil.date_arg(arg) @pytest.mark.parametrize('year', [ 1990, 2000, 2009, ]) def test_year_or_date_arg_year(year): assert cliutil.year_or_date_arg(str(year)) == year @pytest.mark.parametrize('year,month,day', [ (2000, 1, 1), (2016, 2, 29), (2020, 12, 31), ]) def test_year_or_date_arg_date(year, month, day): expected = datetime.date(year, month, day) assert cliutil.year_or_date_arg(expected.isoformat()) == expected @pytest.mark.parametrize('arg', [ '-1', str(sys.maxsize), 'MMDVIII', '2019-02-29', ]) def test_year_or_date_arg_invalid(arg): with pytest.raises(ValueError): cliutil.year_or_date_arg(arg) @pytest.mark.parametrize('func_name', [ 'bytes_output', 'text_output', ]) def test_default_output(func_name): actual = getattr(cliutil, func_name)() assert actual.fileno() == sys.stdout.fileno() @pytest.mark.parametrize('path_name', FILE_NAMES) def test_text_output_path(path_name, tmp_path): path = tmp_path / path_name stream = io.StringIO() actual = cliutil.text_output(path, stream) assert actual is not stream assert str(actual.name) == str(path) assert 'w' in actual.mode @pytest.mark.parametrize('path', STREAM_PATHS) def test_text_output_stream(path): stream = io.StringIO() actual = cliutil.text_output(path, stream) assert actual is stream @pytest.mark.parametrize('errnum', [ errno.EACCES, errno.EPERM, errno.ENOENT, ]) def test_excepthook_oserror(errnum, caplog): error = OSError(errnum, os.strerror(errnum), 'TestFilename') with pytest.raises(SystemExit) as exc_check: cliutil.ExceptHook()(type(error), error, None) assert exc_check.value.args[0] == 4 assert 
caplog.records for log in caplog.records: assert log.levelname == 'CRITICAL' assert log.message == f"I/O error: {error.filename}: {error.strerror}" @pytest.mark.parametrize('exc_type', [ AttributeError, RuntimeError, ValueError, ]) def test_excepthook_bug(exc_type, caplog): error = exc_type("test message") with pytest.raises(SystemExit) as exc_check: cliutil.ExceptHook()(exc_type, error, None) assert exc_check.value.args[0] == 3 assert caplog.records for log in caplog.records: assert log.levelname == 'CRITICAL' assert log.message == f"internal {exc_type.__name__}: {error.args[0]}" def test_excepthook_traceback(caplog): error = KeyError('test') args = (type(error), error, MockTraceback()) caplog.set_level(logging.DEBUG) with pytest.raises(SystemExit) as exc_check: cliutil.ExceptHook()(*args) assert caplog.records assert caplog.records[-1].message == ''.join(traceback.format_exception(*args)) @pytest.mark.parametrize('arg,expected', [ ('debug', logging.DEBUG), ('info', logging.INFO), ('warning', logging.WARNING), ('warn', logging.WARNING), ('error', logging.ERROR), ('err', logging.ERROR), ('critical', logging.CRITICAL), ('crit', logging.CRITICAL), ]) def test_loglevel_argument(argparser, arg, expected): for method in ['lower', 'title', 'upper']: args = argparser.parse_args(['--loglevel', getattr(arg, method)()]) assert args.loglevel is expected def test_setup_logger(): stream = io.StringIO() logger = cliutil.setup_logger( 'test_cliutil', stream, '%(name)s %(levelname)s: %(message)s', ) logger.critical("test crit") assert stream.getvalue() == "test_cliutil CRITICAL: test crit\n" @pytest.mark.parametrize('arg', [ '--license', '--version', '--copyright', ]) def test_version_argument(argparser, capsys, arg): with pytest.raises(SystemExit) as exc_check: args = argparser.parse_args(['--version']) assert exc_check.value.args[0] == 0 stdout, _ = capsys.readouterr() lines = iter(stdout.splitlines()) assert re.match(r'^test_cliutil version \d+\.\d+\.\d+', next(lines, "")) 
@pytest.mark.parametrize('date,diff,expected', [
    # Starting from the last day of February in a non-leap year
    (datetime.date(2010, 2, 28), 0, datetime.date(2010, 2, 28)),
    (datetime.date(2010, 2, 28), 1, datetime.date(2011, 2, 28)),
    (datetime.date(2010, 2, 28), 2, datetime.date(2012, 2, 28)),
    (datetime.date(2010, 2, 28), -1, datetime.date(2009, 2, 28)),
    (datetime.date(2010, 2, 28), -2, datetime.date(2008, 2, 28)),
    # Starting from a leap day
    (datetime.date(2012, 2, 29), 2, datetime.date(2014, 3, 1)),
    (datetime.date(2012, 2, 29), 4, datetime.date(2016, 2, 29)),
    (datetime.date(2012, 2, 29), -2, datetime.date(2010, 2, 28)),
    (datetime.date(2012, 2, 29), -4, datetime.date(2008, 2, 29)),
    # Starting from the first day of March
    (datetime.date(2010, 3, 1), 1, datetime.date(2011, 3, 1)),
    (datetime.date(2010, 3, 1), 2, datetime.date(2012, 3, 1)),
    (datetime.date(2010, 3, 1), -1, datetime.date(2009, 3, 1)),
    (datetime.date(2010, 3, 1), -2, datetime.date(2008, 3, 1)),
])
def test_diff_year(date, diff, expected):
    """diff_year() shifts ``date`` by ``diff`` years to give ``expected``."""
    shifted = cliutil.diff_year(date, diff)
    assert shifted == expected