cliutil: Better implementation of is_main_script.

The old one could return True if you called accrual.main()
directly from one-off test scripts.
This commit is contained in:
Brett Smith 2020-05-29 23:39:27 -04:00
parent e07a47ec8f
commit 32b62df540
3 changed files with 21 additions and 6 deletions

View file

@ -28,6 +28,8 @@ import sys
import traceback import traceback
import types import types
from pathlib import Path
from typing import ( from typing import (
Any, Any,
Iterable, Iterable,
@ -134,10 +136,14 @@ def add_version_argument(parser: argparse.ArgumentParser) -> argparse.Action:
help="Show program version and license information", help="Show program version and license information",
) )
def is_main_script() -> bool: def is_main_script(prog_name: str) -> bool:
"""Return true if the caller is the "main" program.""" """Return true if the caller is the "main" program."""
stack = inspect.stack(context=False) stack = iter(inspect.stack(context=False))
return len(stack) <= 3 and stack[-1].function.startswith('<') next(stack) # Discard the frame for calling this function
caller_filename = next(stack).filename
return all(frame.filename == caller_filename
or Path(frame.filename).stem == prog_name
for frame in stack)
def setup_logger(logger: Union[str, logging.Logger]='', def setup_logger(logger: Union[str, logging.Logger]='',
loglevel: int=logging.INFO, loglevel: int=logging.INFO,

View file

@ -391,7 +391,7 @@ def main(arglist: Optional[Sequence[str]]=None,
stderr: TextIO=sys.stderr, stderr: TextIO=sys.stderr,
config: Optional[configmod.Config]=None, config: Optional[configmod.Config]=None,
) -> int: ) -> int:
if cliutil.is_main_script(): if cliutil.is_main_script(PROGNAME):
global logger global logger
logger = logging.getLogger(PROGNAME) logger = logging.getLogger(PROGNAME)
sys.excepthook = cliutil.ExceptHook(logger) sys.excepthook = cliutil.ExceptHook(logger)

View file

@ -27,6 +27,11 @@ import pytest
from conservancy_beancount import cliutil from conservancy_beancount import cliutil
class AlwaysEqual:
def __eq__(self, other):
return True
class MockTraceback: class MockTraceback:
def __init__(self, stack=None, index=0): def __init__(self, stack=None, index=0):
if stack is None: if stack is None:
@ -91,8 +96,12 @@ def test_excepthook_traceback(caplog):
assert caplog.records assert caplog.records
assert caplog.records[-1].message == ''.join(traceback.format_exception(*args)) assert caplog.records[-1].message == ''.join(traceback.format_exception(*args))
def test_is_main_script(): @pytest.mark.parametrize('prog_name,expected', [
assert not cliutil.is_main_script() ('', False),
(AlwaysEqual(), True),
])
def test_is_main_script(prog_name, expected):
assert cliutil.is_main_script(prog_name) == expected
@pytest.mark.parametrize('arg,expected', [ @pytest.mark.parametrize('arg,expected', [
('debug', logging.DEBUG), ('debug', logging.DEBUG),