cliutil: Add bytes_output() and text_output() functions.
This commit is contained in:
parent
04c804a506
commit
2b5cb0eca6
3 changed files with 106 additions and 32 deletions
|
@ -19,6 +19,7 @@ along with this program. If not, see <https://www.gnu.org/licenses/>."""
|
||||||
import argparse
|
import argparse
|
||||||
import enum
|
import enum
|
||||||
import inspect
|
import inspect
|
||||||
|
import io
|
||||||
import logging
|
import logging
|
||||||
import operator
|
import operator
|
||||||
import os
|
import os
|
||||||
|
@ -36,8 +37,11 @@ from . import filters
|
||||||
from . import rtutil
|
from . import rtutil
|
||||||
|
|
||||||
from typing import (
|
from typing import (
|
||||||
|
cast,
|
||||||
Any,
|
Any,
|
||||||
|
BinaryIO,
|
||||||
Callable,
|
Callable,
|
||||||
|
IO,
|
||||||
Iterable,
|
Iterable,
|
||||||
NamedTuple,
|
NamedTuple,
|
||||||
NoReturn,
|
NoReturn,
|
||||||
|
@ -51,6 +55,9 @@ from .beancount_types import (
|
||||||
MetaKey,
|
MetaKey,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
OutputFile = Union[int, IO]
|
||||||
|
|
||||||
|
STDSTREAM_PATH = Path('-')
|
||||||
VERSION = pkg_resources.require(PKGNAME)[0].version
|
VERSION = pkg_resources.require(PKGNAME)[0].version
|
||||||
|
|
||||||
class ExceptHook:
|
class ExceptHook:
|
||||||
|
@ -247,3 +254,51 @@ def setup_logger(logger: Union[str, logging.Logger]='',
|
||||||
logger.addHandler(handler)
|
logger.addHandler(handler)
|
||||||
logger.setLevel(loglevel)
|
logger.setLevel(loglevel)
|
||||||
return logger
|
return logger
|
||||||
|
|
||||||
|
def bytes_output(path: Optional[Path]=None,
|
||||||
|
default: OutputFile=sys.stdout,
|
||||||
|
mode: str='w',
|
||||||
|
) -> BinaryIO:
|
||||||
|
"""Get a file-like object suitable for binary output
|
||||||
|
|
||||||
|
If ``path`` is ``None`` or ``-``, returns a file-like object backed by
|
||||||
|
``default``. If ``default`` is a file descriptor or text IO object, this
|
||||||
|
method returns a file-like object that writes to the same place.
|
||||||
|
|
||||||
|
Otherwise, returns ``path.open(mode)``.
|
||||||
|
"""
|
||||||
|
mode = f'{mode}b'
|
||||||
|
if path is None or path == STDSTREAM_PATH:
|
||||||
|
if isinstance(default, int):
|
||||||
|
retval = open(default, mode)
|
||||||
|
elif isinstance(default, TextIO):
|
||||||
|
retval = default.buffer
|
||||||
|
else:
|
||||||
|
retval = default
|
||||||
|
else:
|
||||||
|
retval = path.open(mode)
|
||||||
|
return cast(BinaryIO, retval)
|
||||||
|
|
||||||
|
def text_output(path: Optional[Path]=None,
|
||||||
|
default: OutputFile=sys.stdout,
|
||||||
|
mode: str='w',
|
||||||
|
encoding: Optional[str]=None,
|
||||||
|
) -> TextIO:
|
||||||
|
"""Get a file-like object suitable for text output
|
||||||
|
|
||||||
|
If ``path`` is ``None`` or ``-``, returns a file-like object backed by
|
||||||
|
``default``. If ``default`` is a file descriptor or binary IO object, this
|
||||||
|
method returns a file-like object that writes to the same place.
|
||||||
|
|
||||||
|
Otherwise, returns ``path.open(mode)``.
|
||||||
|
"""
|
||||||
|
if path is None or path == STDSTREAM_PATH:
|
||||||
|
if isinstance(default, int):
|
||||||
|
retval = open(default, mode, encoding=encoding)
|
||||||
|
elif isinstance(default, BinaryIO):
|
||||||
|
retval = io.TextIOWrapper(default, encoding=encoding)
|
||||||
|
else:
|
||||||
|
retval = default
|
||||||
|
else:
|
||||||
|
retval = path.open(mode, encoding=encoding)
|
||||||
|
return cast(TextIO, retval)
|
||||||
|
|
|
@ -120,7 +120,6 @@ from .. import filters
|
||||||
from .. import rtutil
|
from .. import rtutil
|
||||||
|
|
||||||
PROGNAME = 'accrual-report'
|
PROGNAME = 'accrual-report'
|
||||||
STANDARD_PATH = Path('-')
|
|
||||||
|
|
||||||
CompoundAmount = TypeVar('CompoundAmount', data.Amount, core.Balance)
|
CompoundAmount = TypeVar('CompoundAmount', data.Amount, core.Balance)
|
||||||
PostGroups = Mapping[Optional[MetaValue], 'AccrualPostings']
|
PostGroups = Mapping[Optional[MetaValue], 'AccrualPostings']
|
||||||
|
@ -677,28 +676,6 @@ metadata to match. A single ticket number is a shortcut for
|
||||||
args.report_type = ReportType.AGING
|
args.report_type = ReportType.AGING
|
||||||
return args
|
return args
|
||||||
|
|
||||||
def get_output_path(output_path: Optional[Path],
|
|
||||||
default_path: Path=STANDARD_PATH,
|
|
||||||
) -> Optional[Path]:
|
|
||||||
if output_path is None:
|
|
||||||
output_path = default_path
|
|
||||||
if output_path == STANDARD_PATH:
|
|
||||||
return None
|
|
||||||
else:
|
|
||||||
return output_path
|
|
||||||
|
|
||||||
def get_output_bin(path: Optional[Path], stdout: TextIO) -> BinaryIO:
|
|
||||||
if path is None:
|
|
||||||
return open(stdout.fileno(), 'wb')
|
|
||||||
else:
|
|
||||||
return path.open('wb')
|
|
||||||
|
|
||||||
def get_output_text(path: Optional[Path], stdout: TextIO) -> TextIO:
|
|
||||||
if path is None:
|
|
||||||
return stdout
|
|
||||||
else:
|
|
||||||
return path.open('w')
|
|
||||||
|
|
||||||
def main(arglist: Optional[Sequence[str]]=None,
|
def main(arglist: Optional[Sequence[str]]=None,
|
||||||
stdout: TextIO=sys.stdout,
|
stdout: TextIO=sys.stdout,
|
||||||
stderr: TextIO=sys.stderr,
|
stderr: TextIO=sys.stderr,
|
||||||
|
@ -762,29 +739,26 @@ def main(arglist: Optional[Sequence[str]]=None,
|
||||||
logger.error("unable to generate aging report: RT client is required")
|
logger.error("unable to generate aging report: RT client is required")
|
||||||
else:
|
else:
|
||||||
now = datetime.datetime.now()
|
now = datetime.datetime.now()
|
||||||
default_path = Path(now.strftime('AgingReport_%Y-%m-%d_%H:%M.ods'))
|
if args.output_file is None:
|
||||||
output_path = get_output_path(args.output_file, default_path)
|
args.output_file = Path(now.strftime('AgingReport_%Y-%m-%d_%H:%M.ods'))
|
||||||
out_bin = get_output_bin(output_path, stdout)
|
logger.info("Writing report to %s", args.output_file)
|
||||||
|
out_bin = cliutil.bytes_output(args.output_file, stdout)
|
||||||
report = AgingReport(rt_client, out_bin)
|
report = AgingReport(rt_client, out_bin)
|
||||||
elif args.report_type is ReportType.OUTGOING:
|
elif args.report_type is ReportType.OUTGOING:
|
||||||
rt_client = config.rt_client()
|
rt_client = config.rt_client()
|
||||||
if rt_client is None:
|
if rt_client is None:
|
||||||
logger.error("unable to generate outgoing report: RT client is required")
|
logger.error("unable to generate outgoing report: RT client is required")
|
||||||
else:
|
else:
|
||||||
output_path = get_output_path(args.output_file)
|
out_file = cliutil.text_output(args.output_file, stdout)
|
||||||
out_file = get_output_text(output_path, stdout)
|
|
||||||
report = OutgoingReport(rt_client, out_file)
|
report = OutgoingReport(rt_client, out_file)
|
||||||
else:
|
else:
|
||||||
output_path = get_output_path(args.output_file)
|
out_file = cliutil.text_output(args.output_file, stdout)
|
||||||
out_file = get_output_text(output_path, stdout)
|
|
||||||
report = args.report_type.value(out_file)
|
report = args.report_type.value(out_file)
|
||||||
|
|
||||||
if report is None:
|
if report is None:
|
||||||
returncode |= ReturnFlag.REPORT_ERRORS
|
returncode |= ReturnFlag.REPORT_ERRORS
|
||||||
else:
|
else:
|
||||||
report.run(groups)
|
report.run(groups)
|
||||||
if args.output_file != output_path:
|
|
||||||
logger.info("Report saved to %s", output_path)
|
|
||||||
return 0 if returncode == 0 else 16 + returncode
|
return 0 if returncode == 0 else 16 + returncode
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
|
@ -21,12 +21,18 @@ import inspect
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
from conservancy_beancount import cliutil
|
from conservancy_beancount import cliutil
|
||||||
|
|
||||||
|
FILE_NAMES = ['-foobar', '-foo.bin']
|
||||||
|
STREAM_PATHS = [None, Path('-')]
|
||||||
|
|
||||||
class AlwaysEqual:
|
class AlwaysEqual:
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
return True
|
return True
|
||||||
|
@ -57,6 +63,45 @@ def argparser():
|
||||||
cliutil.add_version_argument(parser)
|
cliutil.add_version_argument(parser)
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('path_name', FILE_NAMES)
|
||||||
|
def test_bytes_output_path(path_name, tmp_path):
|
||||||
|
path = tmp_path / path_name
|
||||||
|
stream = io.BytesIO()
|
||||||
|
actual = cliutil.bytes_output(path, stream)
|
||||||
|
assert actual is not stream
|
||||||
|
assert str(actual.name) == str(path)
|
||||||
|
assert 'w' in actual.mode
|
||||||
|
assert 'b' in actual.mode
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('path', STREAM_PATHS)
|
||||||
|
def test_bytes_output_stream(path):
|
||||||
|
stream = io.BytesIO()
|
||||||
|
actual = cliutil.bytes_output(path, stream)
|
||||||
|
assert actual is stream
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('func_name', [
|
||||||
|
'bytes_output',
|
||||||
|
'text_output',
|
||||||
|
])
|
||||||
|
def test_default_output(func_name):
|
||||||
|
actual = getattr(cliutil, func_name)()
|
||||||
|
assert actual.fileno() == sys.stdout.fileno()
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('path_name', FILE_NAMES)
|
||||||
|
def test_text_output_path(path_name, tmp_path):
|
||||||
|
path = tmp_path / path_name
|
||||||
|
stream = io.StringIO()
|
||||||
|
actual = cliutil.text_output(path, stream)
|
||||||
|
assert actual is not stream
|
||||||
|
assert str(actual.name) == str(path)
|
||||||
|
assert 'w' in actual.mode
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('path', STREAM_PATHS)
|
||||||
|
def test_text_output_stream(path):
|
||||||
|
stream = io.StringIO()
|
||||||
|
actual = cliutil.text_output(path, stream)
|
||||||
|
assert actual is stream
|
||||||
|
|
||||||
@pytest.mark.parametrize('errnum', [
|
@pytest.mark.parametrize('errnum', [
|
||||||
errno.EACCES,
|
errno.EACCES,
|
||||||
errno.EPERM,
|
errno.EPERM,
|
||||||
|
|
Loading…
Reference in a new issue