cliutil: Add EnumArgument.

This functionality already existed in the code three times, and it's about
to get more important for the ledger report, so now was the time to abstract
it.
This commit is contained in:
Brett Smith 2021-02-17 14:00:06 -05:00
parent fe3560b748
commit 4188dc6a64
3 changed files with 104 additions and 22 deletions

View file

@ -39,6 +39,8 @@ from typing import (
BinaryIO,
Callable,
Container,
Generic,
Hashable,
IO,
Iterable,
NamedTuple,
@ -47,18 +49,75 @@ from typing import (
Sequence,
TextIO,
Type,
TypeVar,
Union,
)
from .beancount_types import (
MetaKey,
)
ET = TypeVar('ET', bound=enum.Enum)
OutputFile = Union[int, IO]
CPU_COUNT = len(os.sched_getaffinity(0))
STDSTREAM_PATH = Path('-')
VERSION = pkg_resources.require(PKGNAME)[0].version
class EnumArgument(Generic[ET]):
"""Wrapper class to use an enum as argument values
Use this class when the user can choose one of some arbitrary enum names
as an argument. It will let user abbreviate and use any case, and will
return the correct value as long as it's unambiguous. Typical usage
looks like::
enum_arg = EnumArgument(Enum)
arg_parser.add_argument(
'--choice',
type=enum_arg.enum_type, # or .value_type
help=f"Choices are {enum_arg.choices_str()}",
)
"""
# I originally wrote this as a mixin class, to eliminate the need for the
# explicit wrapping in the example above. But Python 3.6 doesn't really
# support mixins with Enums; see <https://bugs.python.org/issue29577>.
# This functionality could be moved to a mixin when we drop support for
# Python 3.6.
def __init__(self, base: Type[ET]) -> None:
self.base = base
def enum_type(self, arg: str) -> ET:
"""Return a single enum whose name matches the user argument"""
regexp = re.compile(re.escape(arg), re.IGNORECASE)
matches = frozenset(choice for choice in self.base if regexp.match(choice.name))
count = len(matches)
if count == 1:
return next(iter(matches))
elif count:
names = ', '.join(repr(choice.name) for choice in matches)
raise ValueError(f"ambiguous argument {arg!r}: matches {names}")
else:
raise ValueError(f"unknown argument {arg!r}")
def value_type(self, arg: str) -> Any:
return self.enum_type(arg).value
def choices_str(self, sep: str=', ', fmt: str='{!r}') -> str:
"""Return a user-formatted string of enum names"""
sortkey: Callable[[ET], Hashable] = getattr(
self.base, '_choices_sortkey', self._choices_sortkey,
)
return sep.join(
fmt.format(choice.name.lower())
for choice in sorted(self.base, key=sortkey)
)
def _choices_sortkey(self, choice: ET) -> Hashable:
return choice.name
class ExceptHook:
def __init__(self, logger: Optional[logging.Logger]=None) -> None:
if logger is None:
@ -148,17 +207,8 @@ class LogLevel(enum.IntEnum):
ERR = ERROR
CRIT = CRITICAL
@classmethod
def from_arg(cls, arg: str) -> int:
try:
return cls[arg.upper()].value
except KeyError:
raise ValueError(f"unknown loglevel {arg!r}") from None
@classmethod
def choices(cls) -> Iterable[str]:
for level in sorted(cls, key=operator.attrgetter('value')):
yield level.name.lower()
def _choices_sortkey(self) -> Hashable:
return self.value
class SearchTerm(NamedTuple):
@ -250,14 +300,15 @@ Can specify a positive integer or a percentage of CPU cores. Default all cores.
def add_loglevel_argument(parser: argparse.ArgumentParser,
default: LogLevel=LogLevel.INFO) -> argparse.Action:
arg_enum = EnumArgument(LogLevel)
return parser.add_argument(
'--loglevel',
metavar='LEVEL',
default=default.value,
type=LogLevel.from_arg,
type=arg_enum.value_type,
help="Show logs at this level and above."
f" Specify one of {', '.join(LogLevel.choices())}."
f" Default {default.name.lower()}.",
f" Specify one of {arg_enum.choices_str()}."
f" Default {default.name.lower()!r}.",
)
def add_rewrite_rules_argument(parser: argparse.ArgumentParser) -> argparse.Action:

View file

@ -306,13 +306,6 @@ class ReportType(enum.Enum):
TXT = TEXT
SPREADSHEET = ODS
@classmethod
def from_arg(cls, s: str) -> 'ReportType':
try:
return cls[s.upper()]
except KeyError:
raise ValueError(f"no report type matches {s!r}") from None
def parse_arguments(arglist: Optional[Sequence[str]]=None) -> argparse.Namespace:
parser = argparse.ArgumentParser(prog=PROGNAME)
@ -337,7 +330,7 @@ The default is a year after the start date.
parser.add_argument(
'--report-type', '-t',
metavar='TYPE',
type=ReportType.from_arg,
type=cliutil.EnumArgument(ReportType).enum_type,
help="""Type of report to generate. `text` gives a plain two-column text
report listing accounts and balances over the period, and is the default when
you search for a specific project/fund. `ods` produces a higher-level

View file

@ -7,6 +7,7 @@
import argparse
import datetime
import enum
import errno
import io
import inspect
@ -27,6 +28,12 @@ from conservancy_beancount import cliutil
FILE_NAMES = ['-foobar', '-foo.bin']
STREAM_PATHS = [None, Path('-')]
class ArgChoices(enum.Enum):
AA = 'aa'
AB = 'ab'
BB = 'bb'
class MockTraceback:
def __init__(self, stack=None, index=0):
if stack is None:
@ -45,6 +52,10 @@ class MockTraceback:
return None
@pytest.fixture(scope='module')
def arg_enum():
return cliutil.EnumArgument(ArgChoices)
@pytest.fixture(scope='module')
def argparser():
parser = argparse.ArgumentParser(prog='test_cliutil')
@ -239,3 +250,30 @@ def test_diff_year(date, diff, expected):
])
def test_can_run(cmd, expected):
assert cliutil.can_run(cmd) == expected
@pytest.mark.parametrize('choice', ArgChoices)
def test_enum_arg_enum_type(arg_enum, choice):
assert arg_enum.enum_type(choice.name) is choice
assert arg_enum.enum_type(choice.value) is choice
@pytest.mark.parametrize('arg', 'az\0')
def test_enum_arg_no_enum_match(arg_enum, arg):
with pytest.raises(ValueError):
arg_enum.enum_type(arg)
@pytest.mark.parametrize('choice', ArgChoices)
def test_enum_arg_value_type(arg_enum, choice):
assert arg_enum.value_type(choice.name) == choice.value
assert arg_enum.value_type(choice.value) == choice.value
@pytest.mark.parametrize('arg', 'az\0')
def test_enum_arg_no_value_match(arg_enum, arg):
with pytest.raises(ValueError):
arg_enum.value_type(arg)
def test_enum_arg_choices_str_defaults(arg_enum):
assert arg_enum.choices_str() == ', '.join(repr(c.value) for c in ArgChoices)
def test_enum_arg_choices_str_args(arg_enum):
sep = '/'
assert arg_enum.choices_str(sep, '{}') == sep.join(c.value for c in ArgChoices)