cliutil: Add bytes_output() and text_output() functions.

This commit is contained in:
Brett Smith 2020-06-06 13:32:59 -04:00
parent 04c804a506
commit 2b5cb0eca6
3 changed files with 106 additions and 32 deletions

View file

@ -19,6 +19,7 @@ along with this program. If not, see <https://www.gnu.org/licenses/>."""
import argparse
import enum
import inspect
import io
import logging
import operator
import os
@ -36,8 +37,11 @@ from . import filters
from . import rtutil
from typing import (
cast,
Any,
BinaryIO,
Callable,
IO,
Iterable,
NamedTuple,
NoReturn,
@ -51,6 +55,9 @@ from .beancount_types import (
MetaKey,
)
OutputFile = Union[int, IO]
STDSTREAM_PATH = Path('-')
VERSION = pkg_resources.require(PKGNAME)[0].version
class ExceptHook:
@ -247,3 +254,51 @@ def setup_logger(logger: Union[str, logging.Logger]='',
logger.addHandler(handler)
logger.setLevel(loglevel)
return logger
def bytes_output(path: Optional[Path]=None,
default: OutputFile=sys.stdout,
mode: str='w',
) -> BinaryIO:
"""Get a file-like object suitable for binary output
If ``path`` is ``None`` or ``-``, returns a file-like object backed by
``default``. If ``default`` is a file descriptor or text IO object, this
method returns a file-like object that writes to the same place.
Otherwise, returns ``path.open(mode)``.
"""
mode = f'{mode}b'
if path is None or path == STDSTREAM_PATH:
if isinstance(default, int):
retval = open(default, mode)
elif isinstance(default, TextIO):
retval = default.buffer
else:
retval = default
else:
retval = path.open(mode)
return cast(BinaryIO, retval)
def text_output(path: Optional[Path]=None,
default: OutputFile=sys.stdout,
mode: str='w',
encoding: Optional[str]=None,
) -> TextIO:
"""Get a file-like object suitable for text output
If ``path`` is ``None`` or ``-``, returns a file-like object backed by
``default``. If ``default`` is a file descriptor or binary IO object, this
method returns a file-like object that writes to the same place.
Otherwise, returns ``path.open(mode)``.
"""
if path is None or path == STDSTREAM_PATH:
if isinstance(default, int):
retval = open(default, mode, encoding=encoding)
elif isinstance(default, BinaryIO):
retval = io.TextIOWrapper(default, encoding=encoding)
else:
retval = default
else:
retval = path.open(mode, encoding=encoding)
return cast(TextIO, retval)

View file

@ -120,7 +120,6 @@ from .. import filters
from .. import rtutil
PROGNAME = 'accrual-report'
STANDARD_PATH = Path('-')
CompoundAmount = TypeVar('CompoundAmount', data.Amount, core.Balance)
PostGroups = Mapping[Optional[MetaValue], 'AccrualPostings']
@ -677,28 +676,6 @@ metadata to match. A single ticket number is a shortcut for
args.report_type = ReportType.AGING
return args
def get_output_path(output_path: Optional[Path],
default_path: Path=STANDARD_PATH,
) -> Optional[Path]:
if output_path is None:
output_path = default_path
if output_path == STANDARD_PATH:
return None
else:
return output_path
def get_output_bin(path: Optional[Path], stdout: TextIO) -> BinaryIO:
if path is None:
return open(stdout.fileno(), 'wb')
else:
return path.open('wb')
def get_output_text(path: Optional[Path], stdout: TextIO) -> TextIO:
if path is None:
return stdout
else:
return path.open('w')
def main(arglist: Optional[Sequence[str]]=None,
stdout: TextIO=sys.stdout,
stderr: TextIO=sys.stderr,
@ -762,29 +739,26 @@ def main(arglist: Optional[Sequence[str]]=None,
logger.error("unable to generate aging report: RT client is required")
else:
now = datetime.datetime.now()
default_path = Path(now.strftime('AgingReport_%Y-%m-%d_%H:%M.ods'))
output_path = get_output_path(args.output_file, default_path)
out_bin = get_output_bin(output_path, stdout)
if args.output_file is None:
args.output_file = Path(now.strftime('AgingReport_%Y-%m-%d_%H:%M.ods'))
logger.info("Writing report to %s", args.output_file)
out_bin = cliutil.bytes_output(args.output_file, stdout)
report = AgingReport(rt_client, out_bin)
elif args.report_type is ReportType.OUTGOING:
rt_client = config.rt_client()
if rt_client is None:
logger.error("unable to generate outgoing report: RT client is required")
else:
output_path = get_output_path(args.output_file)
out_file = get_output_text(output_path, stdout)
out_file = cliutil.text_output(args.output_file, stdout)
report = OutgoingReport(rt_client, out_file)
else:
output_path = get_output_path(args.output_file)
out_file = get_output_text(output_path, stdout)
out_file = cliutil.text_output(args.output_file, stdout)
report = args.report_type.value(out_file)
if report is None:
returncode |= ReturnFlag.REPORT_ERRORS
else:
report.run(groups)
if args.output_file != output_path:
logger.info("Report saved to %s", output_path)
return 0 if returncode == 0 else 16 + returncode
if __name__ == '__main__':

View file

@ -21,12 +21,18 @@ import inspect
import logging
import os
import re
import sys
import traceback
import pytest
from pathlib import Path
from conservancy_beancount import cliutil
FILE_NAMES = ['-foobar', '-foo.bin']
STREAM_PATHS = [None, Path('-')]
class AlwaysEqual:
def __eq__(self, other):
return True
@ -57,6 +63,45 @@ def argparser():
cliutil.add_version_argument(parser)
return parser
@pytest.mark.parametrize('path_name', FILE_NAMES)
def test_bytes_output_path(path_name, tmp_path):
path = tmp_path / path_name
stream = io.BytesIO()
actual = cliutil.bytes_output(path, stream)
assert actual is not stream
assert str(actual.name) == str(path)
assert 'w' in actual.mode
assert 'b' in actual.mode
@pytest.mark.parametrize('path', STREAM_PATHS)
def test_bytes_output_stream(path):
stream = io.BytesIO()
actual = cliutil.bytes_output(path, stream)
assert actual is stream
@pytest.mark.parametrize('func_name', [
'bytes_output',
'text_output',
])
def test_default_output(func_name):
actual = getattr(cliutil, func_name)()
assert actual.fileno() == sys.stdout.fileno()
@pytest.mark.parametrize('path_name', FILE_NAMES)
def test_text_output_path(path_name, tmp_path):
path = tmp_path / path_name
stream = io.StringIO()
actual = cliutil.text_output(path, stream)
assert actual is not stream
assert str(actual.name) == str(path)
assert 'w' in actual.mode
@pytest.mark.parametrize('path', STREAM_PATHS)
def test_text_output_stream(path):
stream = io.StringIO()
actual = cliutil.text_output(path, stream)
assert actual is stream
@pytest.mark.parametrize('errnum', [
errno.EACCES,
errno.EPERM,