cliutil: Add make_entry_point() function.
This provides better logging setup, reduces the amount of boilerplate in main, and replaces is_main_script().
This commit is contained in:
parent
2a33e17892
commit
cd578289c4
6 changed files with 60 additions and 43 deletions
|
@ -18,7 +18,6 @@ along with this program. If not, see <https://www.gnu.org/licenses/>."""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import enum
|
import enum
|
||||||
import inspect
|
|
||||||
import io
|
import io
|
||||||
import logging
|
import logging
|
||||||
import operator
|
import operator
|
||||||
|
@ -232,29 +231,50 @@ def add_version_argument(parser: argparse.ArgumentParser) -> argparse.Action:
|
||||||
help="Show program version and license information",
|
help="Show program version and license information",
|
||||||
)
|
)
|
||||||
|
|
||||||
def is_main_script(prog_name: str) -> bool:
|
def make_entry_point(mod_name: str, prog_name: str=sys.argv[0]) -> Callable[[], int]:
|
||||||
"""Return true if the caller is the "main" program."""
|
"""Create an entry_point function for a tool
|
||||||
stack = iter(inspect.stack(context=False))
|
|
||||||
next(stack) # Discard the frame for calling this function
|
The returned function is suitable for use as an entry_point in setup.py.
|
||||||
caller_filename = next(stack).filename
|
It sets up the root logger and excepthook, then calls the module's main
|
||||||
return all(frame.filename == caller_filename
|
function.
|
||||||
or Path(frame.filename).stem == prog_name
|
"""
|
||||||
for frame in stack)
|
def entry_point(): # type:ignore
|
||||||
|
prog_mod = sys.modules[mod_name]
|
||||||
|
setup_logger()
|
||||||
|
prog_mod.logger = logging.getLogger(prog_name)
|
||||||
|
sys.excepthook = ExceptHook(prog_mod.logger)
|
||||||
|
return prog_mod.main()
|
||||||
|
return entry_point
|
||||||
|
|
||||||
def setup_logger(logger: Union[str, logging.Logger]='',
|
def setup_logger(logger: Union[str, logging.Logger]='',
|
||||||
loglevel: int=logging.INFO,
|
|
||||||
stream: TextIO=sys.stderr,
|
stream: TextIO=sys.stderr,
|
||||||
fmt: str='%(name)s: %(levelname)s: %(message)s',
|
fmt: str='%(name)s: %(levelname)s: %(message)s',
|
||||||
) -> logging.Logger:
|
) -> logging.Logger:
|
||||||
|
"""Set up a logger with a StreamHandler with the given format"""
|
||||||
if isinstance(logger, str):
|
if isinstance(logger, str):
|
||||||
logger = logging.getLogger(logger)
|
logger = logging.getLogger(logger)
|
||||||
formatter = logging.Formatter(fmt)
|
formatter = logging.Formatter(fmt)
|
||||||
handler = logging.StreamHandler(stream)
|
handler = logging.StreamHandler(stream)
|
||||||
handler.setFormatter(formatter)
|
handler.setFormatter(formatter)
|
||||||
logger.addHandler(handler)
|
logger.addHandler(handler)
|
||||||
logger.setLevel(loglevel)
|
|
||||||
return logger
|
return logger
|
||||||
|
|
||||||
|
def set_loglevel(logger: logging.Logger, loglevel: int=logging.INFO) -> None:
|
||||||
|
"""Set the loglevel for a tool or module
|
||||||
|
|
||||||
|
If the given logger is not under a hierarchy, this function sets the
|
||||||
|
loglevel for the root logger, along with some specific levels for libraries
|
||||||
|
used by reporting tools. Otherwise, it's the same as
|
||||||
|
``logger.setLevel(loglevel)``.
|
||||||
|
"""
|
||||||
|
if '.' not in logger.name:
|
||||||
|
logger = logging.getLogger()
|
||||||
|
if loglevel <= logging.DEBUG:
|
||||||
|
# At the debug level, the rt module logs the full body of every
|
||||||
|
# request and response. That's too much.
|
||||||
|
logging.getLogger('rt.rt').setLevel(logging.INFO)
|
||||||
|
logger.setLevel(loglevel)
|
||||||
|
|
||||||
def bytes_output(path: Optional[Path]=None,
|
def bytes_output(path: Optional[Path]=None,
|
||||||
default: OutputFile=sys.stdout,
|
default: OutputFile=sys.stdout,
|
||||||
mode: str='w',
|
mode: str='w',
|
||||||
|
|
|
@ -663,12 +663,8 @@ def main(arglist: Optional[Sequence[str]]=None,
|
||||||
stderr: TextIO=sys.stderr,
|
stderr: TextIO=sys.stderr,
|
||||||
config: Optional[configmod.Config]=None,
|
config: Optional[configmod.Config]=None,
|
||||||
) -> int:
|
) -> int:
|
||||||
if cliutil.is_main_script(PROGNAME):
|
|
||||||
global logger
|
|
||||||
logger = logging.getLogger(PROGNAME)
|
|
||||||
sys.excepthook = cliutil.ExceptHook(logger)
|
|
||||||
args = parse_arguments(arglist)
|
args = parse_arguments(arglist)
|
||||||
cliutil.setup_logger(logger, args.loglevel, stderr)
|
cliutil.set_loglevel(logger, args.loglevel)
|
||||||
if config is None:
|
if config is None:
|
||||||
config = configmod.Config()
|
config = configmod.Config()
|
||||||
config.load_file()
|
config.load_file()
|
||||||
|
@ -753,5 +749,7 @@ def main(arglist: Optional[Sequence[str]]=None,
|
||||||
report.run(groups)
|
report.run(groups)
|
||||||
return 0 if returncode == 0 else 16 + returncode
|
return 0 if returncode == 0 else 16 + returncode
|
||||||
|
|
||||||
|
entry_point = cliutil.make_entry_point(__name__, PROGNAME)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
exit(main())
|
exit(entry_point())
|
||||||
|
|
2
setup.py
2
setup.py
|
@ -35,7 +35,7 @@ setup(
|
||||||
],
|
],
|
||||||
entry_points={
|
entry_points={
|
||||||
'console_scripts': [
|
'console_scripts': [
|
||||||
'accrual-report = conservancy_beancount.reports.accrual:main',
|
'accrual-report = conservancy_beancount.reports.accrual:entry_point',
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
|
@ -33,11 +33,6 @@ from conservancy_beancount import cliutil
|
||||||
FILE_NAMES = ['-foobar', '-foo.bin']
|
FILE_NAMES = ['-foobar', '-foo.bin']
|
||||||
STREAM_PATHS = [None, Path('-')]
|
STREAM_PATHS = [None, Path('-')]
|
||||||
|
|
||||||
class AlwaysEqual:
|
|
||||||
def __eq__(self, other):
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
class MockTraceback:
|
class MockTraceback:
|
||||||
def __init__(self, stack=None, index=0):
|
def __init__(self, stack=None, index=0):
|
||||||
if stack is None:
|
if stack is None:
|
||||||
|
@ -141,13 +136,6 @@ def test_excepthook_traceback(caplog):
|
||||||
assert caplog.records
|
assert caplog.records
|
||||||
assert caplog.records[-1].message == ''.join(traceback.format_exception(*args))
|
assert caplog.records[-1].message == ''.join(traceback.format_exception(*args))
|
||||||
|
|
||||||
@pytest.mark.parametrize('prog_name,expected', [
|
|
||||||
('', False),
|
|
||||||
(AlwaysEqual(), True),
|
|
||||||
])
|
|
||||||
def test_is_main_script(prog_name, expected):
|
|
||||||
assert cliutil.is_main_script(prog_name) == expected
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('arg,expected', [
|
@pytest.mark.parametrize('arg,expected', [
|
||||||
('debug', logging.DEBUG),
|
('debug', logging.DEBUG),
|
||||||
('info', logging.INFO),
|
('info', logging.INFO),
|
||||||
|
@ -166,11 +154,10 @@ def test_loglevel_argument(argparser, arg, expected):
|
||||||
def test_setup_logger():
|
def test_setup_logger():
|
||||||
stream = io.StringIO()
|
stream = io.StringIO()
|
||||||
logger = cliutil.setup_logger(
|
logger = cliutil.setup_logger(
|
||||||
'test_cliutil', logging.INFO, stream, '%(name)s %(levelname)s: %(message)s',
|
'test_cliutil', stream, '%(name)s %(levelname)s: %(message)s',
|
||||||
)
|
)
|
||||||
logger.debug("test debug")
|
logger.critical("test crit")
|
||||||
logger.info("test info")
|
assert stream.getvalue() == "test_cliutil CRITICAL: test crit\n"
|
||||||
assert stream.getvalue() == "test_cliutil INFO: test info\n"
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('arg', [
|
@pytest.mark.parametrize('arg', [
|
||||||
'--license',
|
'--license',
|
||||||
|
|
|
@ -590,12 +590,13 @@ def run_main(arglist, config=None):
|
||||||
retcode = accrual.main(arglist, output, errors, config)
|
retcode = accrual.main(arglist, output, errors, config)
|
||||||
return retcode, output, errors
|
return retcode, output, errors
|
||||||
|
|
||||||
def check_main_fails(arglist, config, error_flags, error_patterns):
|
def check_main_fails(arglist, config, error_flags):
|
||||||
retcode, output, errors = run_main(arglist, config)
|
retcode, output, errors = run_main(arglist, config)
|
||||||
assert retcode > 16
|
assert retcode > 16
|
||||||
assert (retcode - 16) & error_flags
|
assert (retcode - 16) & error_flags
|
||||||
check_output(errors, error_patterns)
|
|
||||||
assert not output.getvalue()
|
assert not output.getvalue()
|
||||||
|
errors.seek(0)
|
||||||
|
return errors
|
||||||
|
|
||||||
@pytest.mark.parametrize('arglist', [
|
@pytest.mark.parametrize('arglist', [
|
||||||
['--report-type=balance', 'entity=EarlyBird'],
|
['--report-type=balance', 'entity=EarlyBird'],
|
||||||
|
@ -686,7 +687,8 @@ def test_main_aging_report(tmp_path, arglist):
|
||||||
check_aging_ods(ods_file, None, recv_rows, pay_rows)
|
check_aging_ods(ods_file, None, recv_rows, pay_rows)
|
||||||
|
|
||||||
def test_main_no_books():
|
def test_main_no_books():
|
||||||
check_main_fails([], testutil.TestConfig(), 1 | 8, [
|
errors = check_main_fails([], testutil.TestConfig(), 1 | 8)
|
||||||
|
testutil.check_lines_match(iter(errors), [
|
||||||
r':[01]: +no books to load in configuration\b',
|
r':[01]: +no books to load in configuration\b',
|
||||||
])
|
])
|
||||||
|
|
||||||
|
@ -695,15 +697,17 @@ def test_main_no_books():
|
||||||
['505/99999'],
|
['505/99999'],
|
||||||
['entity=NonExistent'],
|
['entity=NonExistent'],
|
||||||
])
|
])
|
||||||
def test_main_no_matches(arglist):
|
def test_main_no_matches(arglist, caplog):
|
||||||
check_main_fails(arglist, None, 8, [
|
check_main_fails(arglist, None, 8)
|
||||||
r': WARNING: no matching entries found to report$',
|
testutil.check_logs_match(caplog, [
|
||||||
|
('WARNING', 'no matching entries found to report'),
|
||||||
])
|
])
|
||||||
|
|
||||||
def test_main_no_rt():
|
def test_main_no_rt(caplog):
|
||||||
config = testutil.TestConfig(
|
config = testutil.TestConfig(
|
||||||
books_path=testutil.test_path('books/accruals.beancount'),
|
books_path=testutil.test_path('books/accruals.beancount'),
|
||||||
)
|
)
|
||||||
check_main_fails(['-t', 'out'], config, 4, [
|
check_main_fails(['-t', 'out'], config, 4)
|
||||||
r': ERROR: unable to generate outgoing report: RT client is required\b',
|
testutil.check_logs_match(caplog, [
|
||||||
|
('ERROR', 'unable to generate outgoing report: RT client is required'),
|
||||||
])
|
])
|
||||||
|
|
|
@ -69,6 +69,14 @@ def check_lines_match(lines, expect_patterns, source='output'):
|
||||||
assert any(re.search(pattern, line) for line in lines), \
|
assert any(re.search(pattern, line) for line in lines), \
|
||||||
f"{pattern!r} not found in {source}"
|
f"{pattern!r} not found in {source}"
|
||||||
|
|
||||||
|
def check_logs_match(caplog, expected):
|
||||||
|
records = iter(caplog.records)
|
||||||
|
for exp_level, exp_msg in expected:
|
||||||
|
exp_level = exp_level.upper()
|
||||||
|
assert any(
|
||||||
|
log.levelname == exp_level and log.message == exp_msg for log in records
|
||||||
|
), f"{exp_level} log {exp_msg!r} not found"
|
||||||
|
|
||||||
def check_post_meta(txn, *expected_meta, default=None):
|
def check_post_meta(txn, *expected_meta, default=None):
|
||||||
assert len(txn.postings) == len(expected_meta)
|
assert len(txn.postings) == len(expected_meta)
|
||||||
for post, expected in zip(txn.postings, expected_meta):
|
for post, expected in zip(txn.postings, expected_meta):
|
||||||
|
|
Loading…
Reference in a new issue