From 4188dc6a64a8702362aff744cc5ce1bf49e92254 Mon Sep 17 00:00:00 2001 From: Brett Smith Date: Wed, 17 Feb 2021 14:00:06 -0500 Subject: [PATCH] cliutil: Add EnumArgument. This functionality already existed in the code three times, and it's about to get more important for the ledger report, so now was the time to abstract it. --- conservancy_beancount/cliutil.py | 79 ++++++++++++++++++++++----- conservancy_beancount/reports/fund.py | 9 +-- tests/test_cliutil.py | 38 +++++++++++++ 3 files changed, 104 insertions(+), 22 deletions(-) diff --git a/conservancy_beancount/cliutil.py b/conservancy_beancount/cliutil.py index 40d8ab9..358c568 100644 --- a/conservancy_beancount/cliutil.py +++ b/conservancy_beancount/cliutil.py @@ -39,6 +39,8 @@ from typing import ( BinaryIO, Callable, Container, + Generic, + Hashable, IO, Iterable, NamedTuple, @@ -47,18 +49,75 @@ from typing import ( Sequence, TextIO, Type, + TypeVar, Union, ) from .beancount_types import ( MetaKey, ) +ET = TypeVar('ET', bound=enum.Enum) OutputFile = Union[int, IO] CPU_COUNT = len(os.sched_getaffinity(0)) STDSTREAM_PATH = Path('-') VERSION = pkg_resources.require(PKGNAME)[0].version +class EnumArgument(Generic[ET]): + """Wrapper class to use an enum as argument values + + Use this class when the user can choose one of some arbitrary enum names + as an argument. It will let user abbreviate and use any case, and will + return the correct value as long as it's unambiguous. Typical usage + looks like:: + + enum_arg = EnumArgument(Enum) + arg_parser.add_argument( + '--choice', + type=enum_arg.enum_type, # or .value_type + help=f"Choices are {enum_arg.choices_str()}", + … + ) + """ + # I originally wrote this as a mixin class, to eliminate the need for the + # explicit wrapping in the example above. But Python 3.6 doesn't really + # support mixins with Enums; see . + # This functionality could be moved to a mixin when we drop support for + # Python 3.6. + + def __init__(self, base: Type[ET]) -> None: + self.base = base + + def enum_type(self, arg: str) -> ET: + """Return a single enum whose name matches the user argument""" + regexp = re.compile(re.escape(arg), re.IGNORECASE) + matches = frozenset(choice for choice in self.base if regexp.match(choice.name)) + count = len(matches) + if count == 1: + return next(iter(matches)) + elif count: + names = ', '.join(repr(choice.name) for choice in matches) + raise ValueError(f"ambiguous argument {arg!r}: matches {names}") + else: + raise ValueError(f"unknown argument {arg!r}") + + def value_type(self, arg: str) -> Any: + return self.enum_type(arg).value + + def choices_str(self, sep: str=', ', fmt: str='{!r}') -> str: + """Return a user-formatted string of enum names""" + sortkey: Callable[[ET], Hashable] = getattr( + self.base, '_choices_sortkey', self._choices_sortkey, + ) + return sep.join( + fmt.format(choice.name.lower()) + for choice in sorted(self.base, key=sortkey) + ) + + def _choices_sortkey(self, choice: ET) -> Hashable: + return choice.name + + class ExceptHook: def __init__(self, logger: Optional[logging.Logger]=None) -> None: if logger is None: @@ -148,17 +207,8 @@ class LogLevel(enum.IntEnum): ERR = ERROR CRIT = CRITICAL - @classmethod - def from_arg(cls, arg: str) -> int: - try: - return cls[arg.upper()].value - except KeyError: - raise ValueError(f"unknown loglevel {arg!r}") from None - - @classmethod - def choices(cls) -> Iterable[str]: - for level in sorted(cls, key=operator.attrgetter('value')): - yield level.name.lower() + def _choices_sortkey(self) -> Hashable: + return self.value class SearchTerm(NamedTuple): @@ -250,14 +300,15 @@ Can specify a positive integer or a percentage of CPU cores. Default all cores. def add_loglevel_argument(parser: argparse.ArgumentParser, default: LogLevel=LogLevel.INFO) -> argparse.Action: + arg_enum = EnumArgument(LogLevel) return parser.add_argument( '--loglevel', metavar='LEVEL', default=default.value, - type=LogLevel.from_arg, + type=arg_enum.value_type, help="Show logs at this level and above." - f" Specify one of {', '.join(LogLevel.choices())}." - f" Default {default.name.lower()}.", + f" Specify one of {arg_enum.choices_str()}." + f" Default {default.name.lower()!r}.", ) def add_rewrite_rules_argument(parser: argparse.ArgumentParser) -> argparse.Action: diff --git a/conservancy_beancount/reports/fund.py b/conservancy_beancount/reports/fund.py index 872d61e..637ac34 100644 --- a/conservancy_beancount/reports/fund.py +++ b/conservancy_beancount/reports/fund.py @@ -306,13 +306,6 @@ class ReportType(enum.Enum): TXT = TEXT SPREADSHEET = ODS - @classmethod - def from_arg(cls, s: str) -> 'ReportType': - try: - return cls[s.upper()] - except KeyError: - raise ValueError(f"no report type matches {s!r}") from None - def parse_arguments(arglist: Optional[Sequence[str]]=None) -> argparse.Namespace: parser = argparse.ArgumentParser(prog=PROGNAME) @@ -337,7 +330,7 @@ The default is a year after the start date. parser.add_argument( '--report-type', '-t', metavar='TYPE', - type=ReportType.from_arg, + type=cliutil.EnumArgument(ReportType).enum_type, help="""Type of report to generate. `text` gives a plain two-column text report listing accounts and balances over the period, and is the default when you search for a specific project/fund. `ods` produces a higher-level diff --git a/tests/test_cliutil.py b/tests/test_cliutil.py index 20ea61e..de95f11 100644 --- a/tests/test_cliutil.py +++ b/tests/test_cliutil.py @@ -7,6 +7,7 @@ import argparse import datetime +import enum import errno import io import inspect @@ -27,6 +28,12 @@ from conservancy_beancount import cliutil FILE_NAMES = ['-foobar', '-foo.bin'] STREAM_PATHS = [None, Path('-')] +class ArgChoices(enum.Enum): + AA = 'aa' + AB = 'ab' + BB = 'bb' + + class MockTraceback: def __init__(self, stack=None, index=0): if stack is None: @@ -45,6 +52,10 @@ class MockTraceback: return None +@pytest.fixture(scope='module') +def arg_enum(): + return cliutil.EnumArgument(ArgChoices) + @pytest.fixture(scope='module') def argparser(): parser = argparse.ArgumentParser(prog='test_cliutil') @@ -239,3 +250,30 @@ def test_diff_year(date, diff, expected): ]) def test_can_run(cmd, expected): assert cliutil.can_run(cmd) == expected + +@pytest.mark.parametrize('choice', ArgChoices) +def test_enum_arg_enum_type(arg_enum, choice): + assert arg_enum.enum_type(choice.name) is choice + assert arg_enum.enum_type(choice.value) is choice + +@pytest.mark.parametrize('arg', 'az\0') +def test_enum_arg_no_enum_match(arg_enum, arg): + with pytest.raises(ValueError): + arg_enum.enum_type(arg) + +@pytest.mark.parametrize('choice', ArgChoices) +def test_enum_arg_value_type(arg_enum, choice): + assert arg_enum.value_type(choice.name) == choice.value + assert arg_enum.value_type(choice.value) == choice.value + +@pytest.mark.parametrize('arg', 'az\0') +def test_enum_arg_no_value_match(arg_enum, arg): + with pytest.raises(ValueError): + arg_enum.value_type(arg) + +def test_enum_arg_choices_str_defaults(arg_enum): + assert arg_enum.choices_str() == ', '.join(repr(c.value) for c in ArgChoices) + +def test_enum_arg_choices_str_args(arg_enum): + sep = '/' + assert arg_enum.choices_str(sep, '{}') == sep.join(c.value for c in ArgChoices)