cliutil: Add bytes_output() and text_output() functions.
This commit is contained in:
parent
04c804a506
commit
2b5cb0eca6
3 changed files with 106 additions and 32 deletions
|
@ -19,6 +19,7 @@ along with this program. If not, see <https://www.gnu.org/licenses/>."""
|
|||
import argparse
|
||||
import enum
|
||||
import inspect
|
||||
import io
|
||||
import logging
|
||||
import operator
|
||||
import os
|
||||
|
@ -36,8 +37,11 @@ from . import filters
|
|||
from . import rtutil
|
||||
|
||||
from typing import (
|
||||
cast,
|
||||
Any,
|
||||
BinaryIO,
|
||||
Callable,
|
||||
IO,
|
||||
Iterable,
|
||||
NamedTuple,
|
||||
NoReturn,
|
||||
|
@ -51,6 +55,9 @@ from .beancount_types import (
|
|||
MetaKey,
|
||||
)
|
||||
|
||||
OutputFile = Union[int, IO]
|
||||
|
||||
STDSTREAM_PATH = Path('-')
|
||||
VERSION = pkg_resources.require(PKGNAME)[0].version
|
||||
|
||||
class ExceptHook:
|
||||
|
@ -247,3 +254,51 @@ def setup_logger(logger: Union[str, logging.Logger]='',
|
|||
logger.addHandler(handler)
|
||||
logger.setLevel(loglevel)
|
||||
return logger
|
||||
|
||||
def bytes_output(path: Optional[Path]=None,
|
||||
default: OutputFile=sys.stdout,
|
||||
mode: str='w',
|
||||
) -> BinaryIO:
|
||||
"""Get a file-like object suitable for binary output
|
||||
|
||||
If ``path`` is ``None`` or ``-``, returns a file-like object backed by
|
||||
``default``. If ``default`` is a file descriptor or text IO object, this
|
||||
method returns a file-like object that writes to the same place.
|
||||
|
||||
Otherwise, returns ``path.open(mode)``.
|
||||
"""
|
||||
mode = f'{mode}b'
|
||||
if path is None or path == STDSTREAM_PATH:
|
||||
if isinstance(default, int):
|
||||
retval = open(default, mode)
|
||||
elif isinstance(default, TextIO):
|
||||
retval = default.buffer
|
||||
else:
|
||||
retval = default
|
||||
else:
|
||||
retval = path.open(mode)
|
||||
return cast(BinaryIO, retval)
|
||||
|
||||
def text_output(path: Optional[Path]=None,
|
||||
default: OutputFile=sys.stdout,
|
||||
mode: str='w',
|
||||
encoding: Optional[str]=None,
|
||||
) -> TextIO:
|
||||
"""Get a file-like object suitable for text output
|
||||
|
||||
If ``path`` is ``None`` or ``-``, returns a file-like object backed by
|
||||
``default``. If ``default`` is a file descriptor or binary IO object, this
|
||||
method returns a file-like object that writes to the same place.
|
||||
|
||||
Otherwise, returns ``path.open(mode)``.
|
||||
"""
|
||||
if path is None or path == STDSTREAM_PATH:
|
||||
if isinstance(default, int):
|
||||
retval = open(default, mode, encoding=encoding)
|
||||
elif isinstance(default, BinaryIO):
|
||||
retval = io.TextIOWrapper(default, encoding=encoding)
|
||||
else:
|
||||
retval = default
|
||||
else:
|
||||
retval = path.open(mode, encoding=encoding)
|
||||
return cast(TextIO, retval)
|
||||
|
|
|
@ -120,7 +120,6 @@ from .. import filters
|
|||
from .. import rtutil
|
||||
|
||||
PROGNAME = 'accrual-report'
|
||||
STANDARD_PATH = Path('-')
|
||||
|
||||
CompoundAmount = TypeVar('CompoundAmount', data.Amount, core.Balance)
|
||||
PostGroups = Mapping[Optional[MetaValue], 'AccrualPostings']
|
||||
|
@ -677,28 +676,6 @@ metadata to match. A single ticket number is a shortcut for
|
|||
args.report_type = ReportType.AGING
|
||||
return args
|
||||
|
||||
def get_output_path(output_path: Optional[Path],
|
||||
default_path: Path=STANDARD_PATH,
|
||||
) -> Optional[Path]:
|
||||
if output_path is None:
|
||||
output_path = default_path
|
||||
if output_path == STANDARD_PATH:
|
||||
return None
|
||||
else:
|
||||
return output_path
|
||||
|
||||
def get_output_bin(path: Optional[Path], stdout: TextIO) -> BinaryIO:
|
||||
if path is None:
|
||||
return open(stdout.fileno(), 'wb')
|
||||
else:
|
||||
return path.open('wb')
|
||||
|
||||
def get_output_text(path: Optional[Path], stdout: TextIO) -> TextIO:
|
||||
if path is None:
|
||||
return stdout
|
||||
else:
|
||||
return path.open('w')
|
||||
|
||||
def main(arglist: Optional[Sequence[str]]=None,
|
||||
stdout: TextIO=sys.stdout,
|
||||
stderr: TextIO=sys.stderr,
|
||||
|
@ -762,29 +739,26 @@ def main(arglist: Optional[Sequence[str]]=None,
|
|||
logger.error("unable to generate aging report: RT client is required")
|
||||
else:
|
||||
now = datetime.datetime.now()
|
||||
default_path = Path(now.strftime('AgingReport_%Y-%m-%d_%H:%M.ods'))
|
||||
output_path = get_output_path(args.output_file, default_path)
|
||||
out_bin = get_output_bin(output_path, stdout)
|
||||
if args.output_file is None:
|
||||
args.output_file = Path(now.strftime('AgingReport_%Y-%m-%d_%H:%M.ods'))
|
||||
logger.info("Writing report to %s", args.output_file)
|
||||
out_bin = cliutil.bytes_output(args.output_file, stdout)
|
||||
report = AgingReport(rt_client, out_bin)
|
||||
elif args.report_type is ReportType.OUTGOING:
|
||||
rt_client = config.rt_client()
|
||||
if rt_client is None:
|
||||
logger.error("unable to generate outgoing report: RT client is required")
|
||||
else:
|
||||
output_path = get_output_path(args.output_file)
|
||||
out_file = get_output_text(output_path, stdout)
|
||||
out_file = cliutil.text_output(args.output_file, stdout)
|
||||
report = OutgoingReport(rt_client, out_file)
|
||||
else:
|
||||
output_path = get_output_path(args.output_file)
|
||||
out_file = get_output_text(output_path, stdout)
|
||||
out_file = cliutil.text_output(args.output_file, stdout)
|
||||
report = args.report_type.value(out_file)
|
||||
|
||||
if report is None:
|
||||
returncode |= ReturnFlag.REPORT_ERRORS
|
||||
else:
|
||||
report.run(groups)
|
||||
if args.output_file != output_path:
|
||||
logger.info("Report saved to %s", output_path)
|
||||
return 0 if returncode == 0 else 16 + returncode
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -21,12 +21,18 @@ import inspect
|
|||
import logging
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
import pytest
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from conservancy_beancount import cliutil
|
||||
|
||||
FILE_NAMES = ['-foobar', '-foo.bin']
|
||||
STREAM_PATHS = [None, Path('-')]
|
||||
|
||||
class AlwaysEqual:
|
||||
def __eq__(self, other):
|
||||
return True
|
||||
|
@ -57,6 +63,45 @@ def argparser():
|
|||
cliutil.add_version_argument(parser)
|
||||
return parser
|
||||
|
||||
@pytest.mark.parametrize('path_name', FILE_NAMES)
|
||||
def test_bytes_output_path(path_name, tmp_path):
|
||||
path = tmp_path / path_name
|
||||
stream = io.BytesIO()
|
||||
actual = cliutil.bytes_output(path, stream)
|
||||
assert actual is not stream
|
||||
assert str(actual.name) == str(path)
|
||||
assert 'w' in actual.mode
|
||||
assert 'b' in actual.mode
|
||||
|
||||
@pytest.mark.parametrize('path', STREAM_PATHS)
|
||||
def test_bytes_output_stream(path):
|
||||
stream = io.BytesIO()
|
||||
actual = cliutil.bytes_output(path, stream)
|
||||
assert actual is stream
|
||||
|
||||
@pytest.mark.parametrize('func_name', [
|
||||
'bytes_output',
|
||||
'text_output',
|
||||
])
|
||||
def test_default_output(func_name):
|
||||
actual = getattr(cliutil, func_name)()
|
||||
assert actual.fileno() == sys.stdout.fileno()
|
||||
|
||||
@pytest.mark.parametrize('path_name', FILE_NAMES)
|
||||
def test_text_output_path(path_name, tmp_path):
|
||||
path = tmp_path / path_name
|
||||
stream = io.StringIO()
|
||||
actual = cliutil.text_output(path, stream)
|
||||
assert actual is not stream
|
||||
assert str(actual.name) == str(path)
|
||||
assert 'w' in actual.mode
|
||||
|
||||
@pytest.mark.parametrize('path', STREAM_PATHS)
|
||||
def test_text_output_stream(path):
|
||||
stream = io.StringIO()
|
||||
actual = cliutil.text_output(path, stream)
|
||||
assert actual is stream
|
||||
|
||||
@pytest.mark.parametrize('errnum', [
|
||||
errno.EACCES,
|
||||
errno.EPERM,
|
||||
|
|
Loading…
Reference in a new issue