cliutil: Add EnumArgument.
This functionality already existed in the code three times, and it's about to get more important for the ledger report, so now was the time to abstract it.
This commit is contained in:
parent
fe3560b748
commit
4188dc6a64
3 changed files with 104 additions and 22 deletions
|
@ -39,6 +39,8 @@ from typing import (
|
|||
BinaryIO,
|
||||
Callable,
|
||||
Container,
|
||||
Generic,
|
||||
Hashable,
|
||||
IO,
|
||||
Iterable,
|
||||
NamedTuple,
|
||||
|
@ -47,18 +49,75 @@ from typing import (
|
|||
Sequence,
|
||||
TextIO,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
from .beancount_types import (
|
||||
MetaKey,
|
||||
)
|
||||
|
||||
ET = TypeVar('ET', bound=enum.Enum)
|
||||
OutputFile = Union[int, IO]
|
||||
|
||||
CPU_COUNT = len(os.sched_getaffinity(0))
|
||||
STDSTREAM_PATH = Path('-')
|
||||
VERSION = pkg_resources.require(PKGNAME)[0].version
|
||||
|
||||
class EnumArgument(Generic[ET]):
|
||||
"""Wrapper class to use an enum as argument values
|
||||
|
||||
Use this class when the user can choose one of some arbitrary enum names
|
||||
as an argument. It will let user abbreviate and use any case, and will
|
||||
return the correct value as long as it's unambiguous. Typical usage
|
||||
looks like::
|
||||
|
||||
enum_arg = EnumArgument(Enum)
|
||||
arg_parser.add_argument(
|
||||
'--choice',
|
||||
type=enum_arg.enum_type, # or .value_type
|
||||
help=f"Choices are {enum_arg.choices_str()}",
|
||||
…
|
||||
)
|
||||
"""
|
||||
# I originally wrote this as a mixin class, to eliminate the need for the
|
||||
# explicit wrapping in the example above. But Python 3.6 doesn't really
|
||||
# support mixins with Enums; see <https://bugs.python.org/issue29577>.
|
||||
# This functionality could be moved to a mixin when we drop support for
|
||||
# Python 3.6.
|
||||
|
||||
def __init__(self, base: Type[ET]) -> None:
|
||||
self.base = base
|
||||
|
||||
def enum_type(self, arg: str) -> ET:
|
||||
"""Return a single enum whose name matches the user argument"""
|
||||
regexp = re.compile(re.escape(arg), re.IGNORECASE)
|
||||
matches = frozenset(choice for choice in self.base if regexp.match(choice.name))
|
||||
count = len(matches)
|
||||
if count == 1:
|
||||
return next(iter(matches))
|
||||
elif count:
|
||||
names = ', '.join(repr(choice.name) for choice in matches)
|
||||
raise ValueError(f"ambiguous argument {arg!r}: matches {names}")
|
||||
else:
|
||||
raise ValueError(f"unknown argument {arg!r}")
|
||||
|
||||
def value_type(self, arg: str) -> Any:
|
||||
return self.enum_type(arg).value
|
||||
|
||||
def choices_str(self, sep: str=', ', fmt: str='{!r}') -> str:
|
||||
"""Return a user-formatted string of enum names"""
|
||||
sortkey: Callable[[ET], Hashable] = getattr(
|
||||
self.base, '_choices_sortkey', self._choices_sortkey,
|
||||
)
|
||||
return sep.join(
|
||||
fmt.format(choice.name.lower())
|
||||
for choice in sorted(self.base, key=sortkey)
|
||||
)
|
||||
|
||||
def _choices_sortkey(self, choice: ET) -> Hashable:
|
||||
return choice.name
|
||||
|
||||
|
||||
class ExceptHook:
|
||||
def __init__(self, logger: Optional[logging.Logger]=None) -> None:
|
||||
if logger is None:
|
||||
|
@ -148,17 +207,8 @@ class LogLevel(enum.IntEnum):
|
|||
ERR = ERROR
|
||||
CRIT = CRITICAL
|
||||
|
||||
@classmethod
|
||||
def from_arg(cls, arg: str) -> int:
|
||||
try:
|
||||
return cls[arg.upper()].value
|
||||
except KeyError:
|
||||
raise ValueError(f"unknown loglevel {arg!r}") from None
|
||||
|
||||
@classmethod
|
||||
def choices(cls) -> Iterable[str]:
|
||||
for level in sorted(cls, key=operator.attrgetter('value')):
|
||||
yield level.name.lower()
|
||||
def _choices_sortkey(self) -> Hashable:
|
||||
return self.value
|
||||
|
||||
|
||||
class SearchTerm(NamedTuple):
|
||||
|
@ -250,14 +300,15 @@ Can specify a positive integer or a percentage of CPU cores. Default all cores.
|
|||
|
||||
def add_loglevel_argument(parser: argparse.ArgumentParser,
|
||||
default: LogLevel=LogLevel.INFO) -> argparse.Action:
|
||||
arg_enum = EnumArgument(LogLevel)
|
||||
return parser.add_argument(
|
||||
'--loglevel',
|
||||
metavar='LEVEL',
|
||||
default=default.value,
|
||||
type=LogLevel.from_arg,
|
||||
type=arg_enum.value_type,
|
||||
help="Show logs at this level and above."
|
||||
f" Specify one of {', '.join(LogLevel.choices())}."
|
||||
f" Default {default.name.lower()}.",
|
||||
f" Specify one of {arg_enum.choices_str()}."
|
||||
f" Default {default.name.lower()!r}.",
|
||||
)
|
||||
|
||||
def add_rewrite_rules_argument(parser: argparse.ArgumentParser) -> argparse.Action:
|
||||
|
|
|
@ -306,13 +306,6 @@ class ReportType(enum.Enum):
|
|||
TXT = TEXT
|
||||
SPREADSHEET = ODS
|
||||
|
||||
@classmethod
|
||||
def from_arg(cls, s: str) -> 'ReportType':
|
||||
try:
|
||||
return cls[s.upper()]
|
||||
except KeyError:
|
||||
raise ValueError(f"no report type matches {s!r}") from None
|
||||
|
||||
|
||||
def parse_arguments(arglist: Optional[Sequence[str]]=None) -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(prog=PROGNAME)
|
||||
|
@ -337,7 +330,7 @@ The default is a year after the start date.
|
|||
parser.add_argument(
|
||||
'--report-type', '-t',
|
||||
metavar='TYPE',
|
||||
type=ReportType.from_arg,
|
||||
type=cliutil.EnumArgument(ReportType).enum_type,
|
||||
help="""Type of report to generate. `text` gives a plain two-column text
|
||||
report listing accounts and balances over the period, and is the default when
|
||||
you search for a specific project/fund. `ods` produces a higher-level
|
||||
|
|
|
@ -7,6 +7,7 @@
|
|||
|
||||
import argparse
|
||||
import datetime
|
||||
import enum
|
||||
import errno
|
||||
import io
|
||||
import inspect
|
||||
|
@ -27,6 +28,12 @@ from conservancy_beancount import cliutil
|
|||
FILE_NAMES = ['-foobar', '-foo.bin']
|
||||
STREAM_PATHS = [None, Path('-')]
|
||||
|
||||
class ArgChoices(enum.Enum):
|
||||
AA = 'aa'
|
||||
AB = 'ab'
|
||||
BB = 'bb'
|
||||
|
||||
|
||||
class MockTraceback:
|
||||
def __init__(self, stack=None, index=0):
|
||||
if stack is None:
|
||||
|
@ -45,6 +52,10 @@ class MockTraceback:
|
|||
return None
|
||||
|
||||
|
||||
@pytest.fixture(scope='module')
|
||||
def arg_enum():
|
||||
return cliutil.EnumArgument(ArgChoices)
|
||||
|
||||
@pytest.fixture(scope='module')
|
||||
def argparser():
|
||||
parser = argparse.ArgumentParser(prog='test_cliutil')
|
||||
|
@ -239,3 +250,30 @@ def test_diff_year(date, diff, expected):
|
|||
])
|
||||
def test_can_run(cmd, expected):
|
||||
assert cliutil.can_run(cmd) == expected
|
||||
|
||||
@pytest.mark.parametrize('choice', ArgChoices)
|
||||
def test_enum_arg_enum_type(arg_enum, choice):
|
||||
assert arg_enum.enum_type(choice.name) is choice
|
||||
assert arg_enum.enum_type(choice.value) is choice
|
||||
|
||||
@pytest.mark.parametrize('arg', 'az\0')
|
||||
def test_enum_arg_no_enum_match(arg_enum, arg):
|
||||
with pytest.raises(ValueError):
|
||||
arg_enum.enum_type(arg)
|
||||
|
||||
@pytest.mark.parametrize('choice', ArgChoices)
|
||||
def test_enum_arg_value_type(arg_enum, choice):
|
||||
assert arg_enum.value_type(choice.name) == choice.value
|
||||
assert arg_enum.value_type(choice.value) == choice.value
|
||||
|
||||
@pytest.mark.parametrize('arg', 'az\0')
|
||||
def test_enum_arg_no_value_match(arg_enum, arg):
|
||||
with pytest.raises(ValueError):
|
||||
arg_enum.value_type(arg)
|
||||
|
||||
def test_enum_arg_choices_str_defaults(arg_enum):
|
||||
assert arg_enum.choices_str() == ', '.join(repr(c.value) for c in ArgChoices)
|
||||
|
||||
def test_enum_arg_choices_str_args(arg_enum):
|
||||
sep = '/'
|
||||
assert arg_enum.choices_str(sep, '{}') == sep.join(c.value for c in ArgChoices)
|
||||
|
|
Loading…
Reference in a new issue