cliutil: Better implementation of is_main_script.
The old one could return True if you called accrual.main() directly from one-off test scripts.
This commit is contained in:
parent
e07a47ec8f
commit
32b62df540
3 changed files with 21 additions and 6 deletions
|
@ -28,6 +28,8 @@ import sys
|
||||||
import traceback
|
import traceback
|
||||||
import types
|
import types
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
Iterable,
|
Iterable,
|
||||||
|
@ -134,10 +136,14 @@ def add_version_argument(parser: argparse.ArgumentParser) -> argparse.Action:
|
||||||
help="Show program version and license information",
|
help="Show program version and license information",
|
||||||
)
|
)
|
||||||
|
|
||||||
def is_main_script() -> bool:
|
def is_main_script(prog_name: str) -> bool:
|
||||||
"""Return true if the caller is the "main" program."""
|
"""Return true if the caller is the "main" program."""
|
||||||
stack = inspect.stack(context=False)
|
stack = iter(inspect.stack(context=False))
|
||||||
return len(stack) <= 3 and stack[-1].function.startswith('<')
|
next(stack) # Discard the frame for calling this function
|
||||||
|
caller_filename = next(stack).filename
|
||||||
|
return all(frame.filename == caller_filename
|
||||||
|
or Path(frame.filename).stem == prog_name
|
||||||
|
for frame in stack)
|
||||||
|
|
||||||
def setup_logger(logger: Union[str, logging.Logger]='',
|
def setup_logger(logger: Union[str, logging.Logger]='',
|
||||||
loglevel: int=logging.INFO,
|
loglevel: int=logging.INFO,
|
||||||
|
|
|
@ -391,7 +391,7 @@ def main(arglist: Optional[Sequence[str]]=None,
|
||||||
stderr: TextIO=sys.stderr,
|
stderr: TextIO=sys.stderr,
|
||||||
config: Optional[configmod.Config]=None,
|
config: Optional[configmod.Config]=None,
|
||||||
) -> int:
|
) -> int:
|
||||||
if cliutil.is_main_script():
|
if cliutil.is_main_script(PROGNAME):
|
||||||
global logger
|
global logger
|
||||||
logger = logging.getLogger(PROGNAME)
|
logger = logging.getLogger(PROGNAME)
|
||||||
sys.excepthook = cliutil.ExceptHook(logger)
|
sys.excepthook = cliutil.ExceptHook(logger)
|
||||||
|
|
|
@ -27,6 +27,11 @@ import pytest
|
||||||
|
|
||||||
from conservancy_beancount import cliutil
|
from conservancy_beancount import cliutil
|
||||||
|
|
||||||
|
class AlwaysEqual:
|
||||||
|
def __eq__(self, other):
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
class MockTraceback:
|
class MockTraceback:
|
||||||
def __init__(self, stack=None, index=0):
|
def __init__(self, stack=None, index=0):
|
||||||
if stack is None:
|
if stack is None:
|
||||||
|
@ -91,8 +96,12 @@ def test_excepthook_traceback(caplog):
|
||||||
assert caplog.records
|
assert caplog.records
|
||||||
assert caplog.records[-1].message == ''.join(traceback.format_exception(*args))
|
assert caplog.records[-1].message == ''.join(traceback.format_exception(*args))
|
||||||
|
|
||||||
def test_is_main_script():
|
@pytest.mark.parametrize('prog_name,expected', [
|
||||||
assert not cliutil.is_main_script()
|
('', False),
|
||||||
|
(AlwaysEqual(), True),
|
||||||
|
])
|
||||||
|
def test_is_main_script(prog_name, expected):
|
||||||
|
assert cliutil.is_main_script(prog_name) == expected
|
||||||
|
|
||||||
@pytest.mark.parametrize('arg,expected', [
|
@pytest.mark.parametrize('arg,expected', [
|
||||||
('debug', logging.DEBUG),
|
('debug', logging.DEBUG),
|
||||||
|
|
Loading…
Reference in a new issue