From ee038d7b7df38f10dc662d129cd028bdc3e85508 Mon Sep 17 00:00:00 2001 From: Brett Smith Date: Sun, 8 Mar 2020 18:24:51 -0400 Subject: [PATCH] plugin.core: Add type hints. --- conservancy_beancount/_typing.py | 58 ++++++++++++++++ conservancy_beancount/plugin/core.py | 98 ++++++++++++++++++++-------- 2 files changed, 130 insertions(+), 26 deletions(-) create mode 100644 conservancy_beancount/_typing.py diff --git a/conservancy_beancount/_typing.py b/conservancy_beancount/_typing.py new file mode 100644 index 0000000..fe22bf8 --- /dev/null +++ b/conservancy_beancount/_typing.py @@ -0,0 +1,58 @@ +"""Type definitions for Conservancy Beancount code""" +# Copyright © 2020 Brett Smith +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +import abc +import datetime + +import beancount.core.data as bc_data + +from typing import ( + Any, + List, + NamedTuple, + Optional, + Set, + Tuple, + Type, +) + +Account = bc_data.Account +HookName = str +MetaKey = str +MetaValue = Any +MetaValueEnum = str +Posting = bc_data.Posting + +class LessComparable(metaclass=abc.ABCMeta): + @abc.abstractmethod + def __le__(self, other: Any) -> bool: ... + + @abc.abstractmethod + def __lt__(self, other: Any) -> bool: ... + + +class Directive(NamedTuple): + meta: bc_data.Meta + date: datetime.date + + +class Transaction(Directive): + flag: bc_data.Flag + payee: Optional[str] + narration: str + tags: Set + links: Set + postings: List[Posting] diff --git a/conservancy_beancount/plugin/core.py b/conservancy_beancount/plugin/core.py index ddf2437..98b70f4 100644 --- a/conservancy_beancount/plugin/core.py +++ b/conservancy_beancount/plugin/core.py @@ -19,14 +19,41 @@ import re from . import errors as errormod +from typing import ( + AbstractSet, + Any, + ClassVar, + Generic, + Iterable, + Iterator, + List, + Mapping, + Optional, + Tuple, + TypeVar, + Union, +) +from .._typing import ( + Account, + HookName, + LessComparable, + MetaKey, + MetaValue, + MetaValueEnum, + Posting, + Transaction, +) + # I expect these will become configurable in the future, which is why I'm # keeping them outside of a class, but for now constants will do. -DEFAULT_START_DATE = datetime.date(2020, 3, 1) +DEFAULT_START_DATE: datetime.date = datetime.date(2020, 3, 1) # The default stop date leaves a little room after so it's easy to test # dates past the far end of the range. -DEFAULT_STOP_DATE = datetime.date(datetime.MAXYEAR, 1, 1) +DEFAULT_STOP_DATE: datetime.date = datetime.date(datetime.MAXYEAR, 1, 1) -class _GenericRange: +CT = TypeVar('CT', bound=LessComparable) + +class _GenericRange(Generic[CT]): """Convenience class to check whether a value is within a range. `foo in generic_range` is equivalent to `start <= foo < stop`. @@ -35,17 +62,17 @@ class _GenericRange: makes it easier for subclasses to override. """ - def __init__(self, start, stop): + def __init__(self, start: CT, stop: CT) -> None: self.start = start self.stop = stop - def __repr__(self): + def __repr__(self) -> str: return "{clsname}({self.start!r}, {self.stop!r})".format( clsname=type(self).__name__, self=self, ) - def __contains__(self, item): + def __contains__(self, item: CT) -> bool: return self.start <= item < self.stop @@ -57,7 +84,11 @@ class MetadataEnum: the primary values. """ - def __init__(self, key, standard_values, aliases_map): + def __init__(self, + key: MetaKey, + standard_values: Iterable[MetaValueEnum], + aliases_map: Mapping[MetaValueEnum, MetaValueEnum], + ) -> None: """Specify allowed values and aliases for this metadata. Arguments: @@ -76,25 +107,28 @@ class MetadataEnum: self._aliases.update((v, v) for v in standard_values) assert self._stdvalues == set(self._aliases.values()) - def __repr__(self): + def __repr__(self) -> str: return "{}<{}>".format(type(self).__name__, self.key) - def __contains__(self, key): + def __contains__(self, key: MetaValueEnum) -> bool: """Returns true if `key` is a standard value or alias.""" return key in self._aliases - def __getitem__(self, key): + def __getitem__(self, key: MetaValueEnum) -> MetaValueEnum: """Return the standard value for `key`. Raises KeyError if `key` is not a known value or alias. """ return self._aliases[key] - def __iter__(self): + def __iter__(self) -> Iterator[MetaValueEnum]: """Iterate over standard values.""" return iter(self._stdvalues) - def get(self, key, default_key=None): + def get(self, + key: MetaValueEnum, + default_key: Optional[MetaValueEnum]=None, + ) -> Optional[MetaValueEnum]: """Return self[key], or a default fallback if that doesn't exist. default_key is another key to look up, *not* a default value to return. @@ -121,18 +155,30 @@ class PostingChecker: # Subclasses may wish to override _default_value and _should_check. # See below. - HOOK_GROUPS = frozenset(['Posting', 'metadata']) - ACCOUNTS = ('',) - TXN_DATE_RANGE = _GenericRange(DEFAULT_START_DATE, DEFAULT_STOP_DATE) - VALUES_ENUM = {} + METADATA_KEY: ClassVar[MetaKey] + VALUES_ENUM: MetadataEnum + HOOK_GROUPS: AbstractSet[HookName] = frozenset(['Posting', 'metadata']) + ACCOUNTS: Union[str, Tuple[Account, ...]] = ('',) + TXN_DATE_RANGE: _GenericRange = _GenericRange(DEFAULT_START_DATE, DEFAULT_STOP_DATE) - def _meta_get(self, txn, post, key, default=None): - try: + def _meta_get(self, + txn: Transaction, + post: Posting, + key: MetaKey, + default: MetaValue=None, + ) -> MetaValue: + if post.meta and key in post.meta: return post.meta[key] - except (KeyError, TypeError): + else: return txn.meta.get(key, default) - def _meta_set(self, txn, post, post_index, key, value): + def _meta_set(self, + txn: Transaction, + post: Posting, + post_index: int, + key: MetaKey, + value: MetaValue, + ) -> None: if post.meta is None: txn.postings[post_index] = Posting(*post[:5], {key: value}) else: @@ -142,22 +188,22 @@ class PostingChecker: # _default_value to get a default. This method should either return # a value string from METADATA_ENUM, or else raise InvalidMetadataError. # This base implementation does the latter. - def _default_value(self, txn, post): + def _default_value(self, txn: Transaction, post: Posting) -> MetaValueEnum: raise errormod.InvalidMetadataError(txn, post, self.METADATA_KEY) # The hook calls _should_check on every posting and only checks postings # when the method returns true. This base method checks the transaction # date is in TXN_DATE_RANGE, and the posting account name matches ACCOUNTS. - def _should_check(self, txn, post): + def _should_check(self, txn: Transaction, post: Posting) -> bool: ok = txn.date in self.TXN_DATE_RANGE if isinstance(self.ACCOUNTS, tuple): ok = ok and post.account.startswith(self.ACCOUNTS) else: - ok = ok and re.search(self.ACCOUNTS, post.account) + ok = ok and bool(re.search(self.ACCOUNTS, post.account)) return ok - def run(self, txn, post): - errors = [] + def run(self, txn: Transaction, post: Posting, post_index: int) -> Iterable[errormod._BaseError]: + errors: List[errormod._BaseError] = [] if not self._should_check(txn, post): return errors source_value = self._meta_get(txn, post, self.METADATA_KEY) @@ -175,5 +221,5 @@ class PostingChecker: txn, post, self.METADATA_KEY, source_value, )) if not errors: - self._meta_set(post, self.METADATA_KEY, set_value) + self._meta_set(txn, post, post_index, self.METADATA_KEY, set_value) return errors