cliutil: Better implementation of is_main_script.
The old one could return True if you called accrual.main() directly from one-off test scripts.
This commit is contained in:
		
							parent
							
								
									e07a47ec8f
								
							
						
					
					
						commit
						32b62df540
					
				
					 3 changed files with 21 additions and 6 deletions
				
			
		| 
						 | 
					@ -28,6 +28,8 @@ import sys
 | 
				
			||||||
import traceback
 | 
					import traceback
 | 
				
			||||||
import types
 | 
					import types
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from pathlib import Path
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from typing import (
 | 
					from typing import (
 | 
				
			||||||
    Any,
 | 
					    Any,
 | 
				
			||||||
    Iterable,
 | 
					    Iterable,
 | 
				
			||||||
| 
						 | 
					@ -134,10 +136,14 @@ def add_version_argument(parser: argparse.ArgumentParser) -> argparse.Action:
 | 
				
			||||||
        help="Show program version and license information",
 | 
					        help="Show program version and license information",
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def is_main_script() -> bool:
 | 
					def is_main_script(prog_name: str) -> bool:
 | 
				
			||||||
    """Return true if the caller is the "main" program."""
 | 
					    """Return true if the caller is the "main" program."""
 | 
				
			||||||
    stack = inspect.stack(context=False)
 | 
					    stack = iter(inspect.stack(context=False))
 | 
				
			||||||
    return len(stack) <= 3 and stack[-1].function.startswith('<')
 | 
					    next(stack)  # Discard the frame for calling this function
 | 
				
			||||||
 | 
					    caller_filename = next(stack).filename
 | 
				
			||||||
 | 
					    return all(frame.filename == caller_filename
 | 
				
			||||||
 | 
					               or Path(frame.filename).stem == prog_name
 | 
				
			||||||
 | 
					               for frame in stack)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def setup_logger(logger: Union[str, logging.Logger]='',
 | 
					def setup_logger(logger: Union[str, logging.Logger]='',
 | 
				
			||||||
                 loglevel: int=logging.INFO,
 | 
					                 loglevel: int=logging.INFO,
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -391,7 +391,7 @@ def main(arglist: Optional[Sequence[str]]=None,
 | 
				
			||||||
         stderr: TextIO=sys.stderr,
 | 
					         stderr: TextIO=sys.stderr,
 | 
				
			||||||
         config: Optional[configmod.Config]=None,
 | 
					         config: Optional[configmod.Config]=None,
 | 
				
			||||||
) -> int:
 | 
					) -> int:
 | 
				
			||||||
    if cliutil.is_main_script():
 | 
					    if cliutil.is_main_script(PROGNAME):
 | 
				
			||||||
        global logger
 | 
					        global logger
 | 
				
			||||||
        logger = logging.getLogger(PROGNAME)
 | 
					        logger = logging.getLogger(PROGNAME)
 | 
				
			||||||
        sys.excepthook = cliutil.ExceptHook(logger)
 | 
					        sys.excepthook = cliutil.ExceptHook(logger)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -27,6 +27,11 @@ import pytest
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from conservancy_beancount import cliutil
 | 
					from conservancy_beancount import cliutil
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class AlwaysEqual:
 | 
				
			||||||
 | 
					    def __eq__(self, other):
 | 
				
			||||||
 | 
					        return True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class MockTraceback:
 | 
					class MockTraceback:
 | 
				
			||||||
    def __init__(self, stack=None, index=0):
 | 
					    def __init__(self, stack=None, index=0):
 | 
				
			||||||
        if stack is None:
 | 
					        if stack is None:
 | 
				
			||||||
| 
						 | 
					@ -91,8 +96,12 @@ def test_excepthook_traceback(caplog):
 | 
				
			||||||
    assert caplog.records
 | 
					    assert caplog.records
 | 
				
			||||||
    assert caplog.records[-1].message == ''.join(traceback.format_exception(*args))
 | 
					    assert caplog.records[-1].message == ''.join(traceback.format_exception(*args))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def test_is_main_script():
 | 
					@pytest.mark.parametrize('prog_name,expected', [
 | 
				
			||||||
    assert not cliutil.is_main_script()
 | 
					    ('', False),
 | 
				
			||||||
 | 
					    (AlwaysEqual(), True),
 | 
				
			||||||
 | 
					])
 | 
				
			||||||
 | 
					def test_is_main_script(prog_name, expected):
 | 
				
			||||||
 | 
					    assert cliutil.is_main_script(prog_name) == expected
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@pytest.mark.parametrize('arg,expected', [
 | 
					@pytest.mark.parametrize('arg,expected', [
 | 
				
			||||||
    ('debug', logging.DEBUG),
 | 
					    ('debug', logging.DEBUG),
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
	Add table
		
		Reference in a new issue