cliutil: Add make_entry_point() function.
This provides better logging setup, reduces the amount of boilerplate in main, and replaces is_main_script().
This commit is contained in:
		
							parent
							
								
									2a33e17892
								
							
						
					
					
						commit
						cd578289c4
					
				
					 6 changed files with 60 additions and 43 deletions
				
			
		| 
						 | 
					@ -18,7 +18,6 @@ along with this program.  If not, see <https://www.gnu.org/licenses/>."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import argparse
 | 
					import argparse
 | 
				
			||||||
import enum
 | 
					import enum
 | 
				
			||||||
import inspect
 | 
					 | 
				
			||||||
import io
 | 
					import io
 | 
				
			||||||
import logging
 | 
					import logging
 | 
				
			||||||
import operator
 | 
					import operator
 | 
				
			||||||
| 
						 | 
					@ -232,29 +231,50 @@ def add_version_argument(parser: argparse.ArgumentParser) -> argparse.Action:
 | 
				
			||||||
        help="Show program version and license information",
 | 
					        help="Show program version and license information",
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def is_main_script(prog_name: str) -> bool:
 | 
					def make_entry_point(mod_name: str, prog_name: str=sys.argv[0]) -> Callable[[], int]:
 | 
				
			||||||
    """Return true if the caller is the "main" program."""
 | 
					    """Create an entry_point function for a tool
 | 
				
			||||||
    stack = iter(inspect.stack(context=False))
 | 
					
 | 
				
			||||||
    next(stack)  # Discard the frame for calling this function
 | 
					    The returned function is suitable for use as an entry_point in setup.py.
 | 
				
			||||||
    caller_filename = next(stack).filename
 | 
					    It sets up the root logger and excepthook, then calls the module's main
 | 
				
			||||||
    return all(frame.filename == caller_filename
 | 
					    function.
 | 
				
			||||||
               or Path(frame.filename).stem == prog_name
 | 
					    """
 | 
				
			||||||
               for frame in stack)
 | 
					    def entry_point():  # type:ignore
 | 
				
			||||||
 | 
					        prog_mod = sys.modules[mod_name]
 | 
				
			||||||
 | 
					        setup_logger()
 | 
				
			||||||
 | 
					        prog_mod.logger = logging.getLogger(prog_name)
 | 
				
			||||||
 | 
					        sys.excepthook = ExceptHook(prog_mod.logger)
 | 
				
			||||||
 | 
					        return prog_mod.main()
 | 
				
			||||||
 | 
					    return entry_point
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def setup_logger(logger: Union[str, logging.Logger]='',
 | 
					def setup_logger(logger: Union[str, logging.Logger]='',
 | 
				
			||||||
                 loglevel: int=logging.INFO,
 | 
					 | 
				
			||||||
                 stream: TextIO=sys.stderr,
 | 
					                 stream: TextIO=sys.stderr,
 | 
				
			||||||
                 fmt: str='%(name)s: %(levelname)s: %(message)s',
 | 
					                 fmt: str='%(name)s: %(levelname)s: %(message)s',
 | 
				
			||||||
) -> logging.Logger:
 | 
					) -> logging.Logger:
 | 
				
			||||||
 | 
					    """Set up a logger with a StreamHandler with the given format"""
 | 
				
			||||||
    if isinstance(logger, str):
 | 
					    if isinstance(logger, str):
 | 
				
			||||||
        logger = logging.getLogger(logger)
 | 
					        logger = logging.getLogger(logger)
 | 
				
			||||||
    formatter = logging.Formatter(fmt)
 | 
					    formatter = logging.Formatter(fmt)
 | 
				
			||||||
    handler = logging.StreamHandler(stream)
 | 
					    handler = logging.StreamHandler(stream)
 | 
				
			||||||
    handler.setFormatter(formatter)
 | 
					    handler.setFormatter(formatter)
 | 
				
			||||||
    logger.addHandler(handler)
 | 
					    logger.addHandler(handler)
 | 
				
			||||||
    logger.setLevel(loglevel)
 | 
					 | 
				
			||||||
    return logger
 | 
					    return logger
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def set_loglevel(logger: logging.Logger, loglevel: int=logging.INFO) -> None:
 | 
				
			||||||
 | 
					    """Set the loglevel for a tool or module
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    If the given logger is not under a hierarchy, this function sets the
 | 
				
			||||||
 | 
					    loglevel for the root logger, along with some specific levels for libraries
 | 
				
			||||||
 | 
					    used by reporting tools. Otherwise, it's the same as
 | 
				
			||||||
 | 
					    ``logger.setLevel(loglevel)``.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    if '.' not in logger.name:
 | 
				
			||||||
 | 
					        logger = logging.getLogger()
 | 
				
			||||||
 | 
					        if loglevel <= logging.DEBUG:
 | 
				
			||||||
 | 
					            # At the debug level, the rt module logs the full body of every
 | 
				
			||||||
 | 
					            # request and response. That's too much.
 | 
				
			||||||
 | 
					            logging.getLogger('rt.rt').setLevel(logging.INFO)
 | 
				
			||||||
 | 
					    logger.setLevel(loglevel)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def bytes_output(path: Optional[Path]=None,
 | 
					def bytes_output(path: Optional[Path]=None,
 | 
				
			||||||
                 default: OutputFile=sys.stdout,
 | 
					                 default: OutputFile=sys.stdout,
 | 
				
			||||||
                 mode: str='w',
 | 
					                 mode: str='w',
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -663,12 +663,8 @@ def main(arglist: Optional[Sequence[str]]=None,
 | 
				
			||||||
         stderr: TextIO=sys.stderr,
 | 
					         stderr: TextIO=sys.stderr,
 | 
				
			||||||
         config: Optional[configmod.Config]=None,
 | 
					         config: Optional[configmod.Config]=None,
 | 
				
			||||||
) -> int:
 | 
					) -> int:
 | 
				
			||||||
    if cliutil.is_main_script(PROGNAME):
 | 
					 | 
				
			||||||
        global logger
 | 
					 | 
				
			||||||
        logger = logging.getLogger(PROGNAME)
 | 
					 | 
				
			||||||
        sys.excepthook = cliutil.ExceptHook(logger)
 | 
					 | 
				
			||||||
    args = parse_arguments(arglist)
 | 
					    args = parse_arguments(arglist)
 | 
				
			||||||
    cliutil.setup_logger(logger, args.loglevel, stderr)
 | 
					    cliutil.set_loglevel(logger, args.loglevel)
 | 
				
			||||||
    if config is None:
 | 
					    if config is None:
 | 
				
			||||||
        config = configmod.Config()
 | 
					        config = configmod.Config()
 | 
				
			||||||
        config.load_file()
 | 
					        config.load_file()
 | 
				
			||||||
| 
						 | 
					@ -753,5 +749,7 @@ def main(arglist: Optional[Sequence[str]]=None,
 | 
				
			||||||
        report.run(groups)
 | 
					        report.run(groups)
 | 
				
			||||||
    return 0 if returncode == 0 else 16 + returncode
 | 
					    return 0 if returncode == 0 else 16 + returncode
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					entry_point = cliutil.make_entry_point(__name__, PROGNAME)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
if __name__ == '__main__':
 | 
					if __name__ == '__main__':
 | 
				
			||||||
    exit(main())
 | 
					    exit(entry_point())
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
							
								
								
									
										2
									
								
								setup.py
									
										
									
									
									
								
							
							
						
						
									
										2
									
								
								setup.py
									
										
									
									
									
								
							| 
						 | 
					@ -35,7 +35,7 @@ setup(
 | 
				
			||||||
    ],
 | 
					    ],
 | 
				
			||||||
    entry_points={
 | 
					    entry_points={
 | 
				
			||||||
        'console_scripts': [
 | 
					        'console_scripts': [
 | 
				
			||||||
            'accrual-report = conservancy_beancount.reports.accrual:main',
 | 
					            'accrual-report = conservancy_beancount.reports.accrual:entry_point',
 | 
				
			||||||
        ],
 | 
					        ],
 | 
				
			||||||
    },
 | 
					    },
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -33,11 +33,6 @@ from conservancy_beancount import cliutil
 | 
				
			||||||
FILE_NAMES = ['-foobar', '-foo.bin']
 | 
					FILE_NAMES = ['-foobar', '-foo.bin']
 | 
				
			||||||
STREAM_PATHS = [None, Path('-')]
 | 
					STREAM_PATHS = [None, Path('-')]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class AlwaysEqual:
 | 
					 | 
				
			||||||
    def __eq__(self, other):
 | 
					 | 
				
			||||||
        return True
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class MockTraceback:
 | 
					class MockTraceback:
 | 
				
			||||||
    def __init__(self, stack=None, index=0):
 | 
					    def __init__(self, stack=None, index=0):
 | 
				
			||||||
        if stack is None:
 | 
					        if stack is None:
 | 
				
			||||||
| 
						 | 
					@ -141,13 +136,6 @@ def test_excepthook_traceback(caplog):
 | 
				
			||||||
    assert caplog.records
 | 
					    assert caplog.records
 | 
				
			||||||
    assert caplog.records[-1].message == ''.join(traceback.format_exception(*args))
 | 
					    assert caplog.records[-1].message == ''.join(traceback.format_exception(*args))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@pytest.mark.parametrize('prog_name,expected', [
 | 
					 | 
				
			||||||
    ('', False),
 | 
					 | 
				
			||||||
    (AlwaysEqual(), True),
 | 
					 | 
				
			||||||
])
 | 
					 | 
				
			||||||
def test_is_main_script(prog_name, expected):
 | 
					 | 
				
			||||||
    assert cliutil.is_main_script(prog_name) == expected
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
@pytest.mark.parametrize('arg,expected', [
 | 
					@pytest.mark.parametrize('arg,expected', [
 | 
				
			||||||
    ('debug', logging.DEBUG),
 | 
					    ('debug', logging.DEBUG),
 | 
				
			||||||
    ('info', logging.INFO),
 | 
					    ('info', logging.INFO),
 | 
				
			||||||
| 
						 | 
					@ -166,11 +154,10 @@ def test_loglevel_argument(argparser, arg, expected):
 | 
				
			||||||
def test_setup_logger():
 | 
					def test_setup_logger():
 | 
				
			||||||
    stream = io.StringIO()
 | 
					    stream = io.StringIO()
 | 
				
			||||||
    logger = cliutil.setup_logger(
 | 
					    logger = cliutil.setup_logger(
 | 
				
			||||||
        'test_cliutil', logging.INFO, stream, '%(name)s %(levelname)s: %(message)s',
 | 
					        'test_cliutil', stream, '%(name)s %(levelname)s: %(message)s',
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    logger.debug("test debug")
 | 
					    logger.critical("test crit")
 | 
				
			||||||
    logger.info("test info")
 | 
					    assert stream.getvalue() == "test_cliutil CRITICAL: test crit\n"
 | 
				
			||||||
    assert stream.getvalue() == "test_cliutil INFO: test info\n"
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
@pytest.mark.parametrize('arg', [
 | 
					@pytest.mark.parametrize('arg', [
 | 
				
			||||||
    '--license',
 | 
					    '--license',
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -590,12 +590,13 @@ def run_main(arglist, config=None):
 | 
				
			||||||
    retcode = accrual.main(arglist, output, errors, config)
 | 
					    retcode = accrual.main(arglist, output, errors, config)
 | 
				
			||||||
    return retcode, output, errors
 | 
					    return retcode, output, errors
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def check_main_fails(arglist, config, error_flags, error_patterns):
 | 
					def check_main_fails(arglist, config, error_flags):
 | 
				
			||||||
    retcode, output, errors = run_main(arglist, config)
 | 
					    retcode, output, errors = run_main(arglist, config)
 | 
				
			||||||
    assert retcode > 16
 | 
					    assert retcode > 16
 | 
				
			||||||
    assert (retcode - 16) & error_flags
 | 
					    assert (retcode - 16) & error_flags
 | 
				
			||||||
    check_output(errors, error_patterns)
 | 
					 | 
				
			||||||
    assert not output.getvalue()
 | 
					    assert not output.getvalue()
 | 
				
			||||||
 | 
					    errors.seek(0)
 | 
				
			||||||
 | 
					    return errors
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@pytest.mark.parametrize('arglist', [
 | 
					@pytest.mark.parametrize('arglist', [
 | 
				
			||||||
    ['--report-type=balance', 'entity=EarlyBird'],
 | 
					    ['--report-type=balance', 'entity=EarlyBird'],
 | 
				
			||||||
| 
						 | 
					@ -686,7 +687,8 @@ def test_main_aging_report(tmp_path, arglist):
 | 
				
			||||||
        check_aging_ods(ods_file, None, recv_rows, pay_rows)
 | 
					        check_aging_ods(ods_file, None, recv_rows, pay_rows)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def test_main_no_books():
 | 
					def test_main_no_books():
 | 
				
			||||||
    check_main_fails([], testutil.TestConfig(), 1 | 8, [
 | 
					    errors = check_main_fails([], testutil.TestConfig(), 1 | 8)
 | 
				
			||||||
 | 
					    testutil.check_lines_match(iter(errors), [
 | 
				
			||||||
        r':[01]: +no books to load in configuration\b',
 | 
					        r':[01]: +no books to load in configuration\b',
 | 
				
			||||||
    ])
 | 
					    ])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -695,15 +697,17 @@ def test_main_no_books():
 | 
				
			||||||
    ['505/99999'],
 | 
					    ['505/99999'],
 | 
				
			||||||
    ['entity=NonExistent'],
 | 
					    ['entity=NonExistent'],
 | 
				
			||||||
])
 | 
					])
 | 
				
			||||||
def test_main_no_matches(arglist):
 | 
					def test_main_no_matches(arglist, caplog):
 | 
				
			||||||
    check_main_fails(arglist, None, 8, [
 | 
					    check_main_fails(arglist, None, 8)
 | 
				
			||||||
        r': WARNING: no matching entries found to report$',
 | 
					    testutil.check_logs_match(caplog, [
 | 
				
			||||||
 | 
					        ('WARNING', 'no matching entries found to report'),
 | 
				
			||||||
    ])
 | 
					    ])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def test_main_no_rt():
 | 
					def test_main_no_rt(caplog):
 | 
				
			||||||
    config = testutil.TestConfig(
 | 
					    config = testutil.TestConfig(
 | 
				
			||||||
        books_path=testutil.test_path('books/accruals.beancount'),
 | 
					        books_path=testutil.test_path('books/accruals.beancount'),
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    check_main_fails(['-t', 'out'], config, 4, [
 | 
					    check_main_fails(['-t', 'out'], config, 4)
 | 
				
			||||||
        r': ERROR: unable to generate outgoing report: RT client is required\b',
 | 
					    testutil.check_logs_match(caplog, [
 | 
				
			||||||
 | 
					        ('ERROR', 'unable to generate outgoing report: RT client is required'),
 | 
				
			||||||
    ])
 | 
					    ])
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -69,6 +69,14 @@ def check_lines_match(lines, expect_patterns, source='output'):
 | 
				
			||||||
        assert any(re.search(pattern, line) for line in lines), \
 | 
					        assert any(re.search(pattern, line) for line in lines), \
 | 
				
			||||||
            f"{pattern!r} not found in {source}"
 | 
					            f"{pattern!r} not found in {source}"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def check_logs_match(caplog, expected):
 | 
				
			||||||
 | 
					    records = iter(caplog.records)
 | 
				
			||||||
 | 
					    for exp_level, exp_msg in expected:
 | 
				
			||||||
 | 
					        exp_level = exp_level.upper()
 | 
				
			||||||
 | 
					        assert any(
 | 
				
			||||||
 | 
					            log.levelname == exp_level and log.message == exp_msg for log in records
 | 
				
			||||||
 | 
					        ), f"{exp_level} log {exp_msg!r} not found"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def check_post_meta(txn, *expected_meta, default=None):
 | 
					def check_post_meta(txn, *expected_meta, default=None):
 | 
				
			||||||
    assert len(txn.postings) == len(expected_meta)
 | 
					    assert len(txn.postings) == len(expected_meta)
 | 
				
			||||||
    for post, expected in zip(txn.postings, expected_meta):
 | 
					    for post, expected in zip(txn.postings, expected_meta):
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
	Add table
		
		Reference in a new issue