plugin.core: Add type hints.

This commit is contained in:
Brett Smith 2020-03-08 18:24:51 -04:00
parent 547ae65780
commit ee038d7b7d
2 changed files with 130 additions and 26 deletions

View file

@ -0,0 +1,58 @@
"""Type definitions for Conservancy Beancount code"""
# Copyright © 2020 Brett Smith
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
import abc
import datetime
import beancount.core.data as bc_data
from typing import (
Any,
List,
NamedTuple,
Optional,
Set,
Tuple,
Type,
)
Account = bc_data.Account
HookName = str
MetaKey = str
MetaValue = Any
MetaValueEnum = str
Posting = bc_data.Posting
class LessComparable(metaclass=abc.ABCMeta):
@abc.abstractmethod
def __le__(self, other: Any) -> bool: ...
@abc.abstractmethod
def __lt__(self, other: Any) -> bool: ...
class Directive(NamedTuple):
meta: bc_data.Meta
date: datetime.date
class Transaction(Directive):
flag: bc_data.Flag
payee: Optional[str]
narration: str
tags: Set
links: Set
postings: List[Posting]

View file

@ -19,14 +19,41 @@ import re
from . import errors as errormod
from typing import (
AbstractSet,
Any,
ClassVar,
Generic,
Iterable,
Iterator,
List,
Mapping,
Optional,
Tuple,
TypeVar,
Union,
)
from .._typing import (
Account,
HookName,
LessComparable,
MetaKey,
MetaValue,
MetaValueEnum,
Posting,
Transaction,
)
# I expect these will become configurable in the future, which is why I'm
# keeping them outside of a class, but for now constants will do.
DEFAULT_START_DATE = datetime.date(2020, 3, 1)
DEFAULT_START_DATE: datetime.date = datetime.date(2020, 3, 1)
# The default stop date leaves a little room after so it's easy to test
# dates past the far end of the range.
DEFAULT_STOP_DATE = datetime.date(datetime.MAXYEAR, 1, 1)
DEFAULT_STOP_DATE: datetime.date = datetime.date(datetime.MAXYEAR, 1, 1)
class _GenericRange:
CT = TypeVar('CT', bound=LessComparable)
class _GenericRange(Generic[CT]):
"""Convenience class to check whether a value is within a range.
`foo in generic_range` is equivalent to `start <= foo < stop`.
@ -35,17 +62,17 @@ class _GenericRange:
makes it easier for subclasses to override.
"""
def __init__(self, start, stop):
def __init__(self, start: CT, stop: CT) -> None:
self.start = start
self.stop = stop
def __repr__(self):
def __repr__(self) -> str:
return "{clsname}({self.start!r}, {self.stop!r})".format(
clsname=type(self).__name__,
self=self,
)
def __contains__(self, item):
def __contains__(self, item: CT) -> bool:
return self.start <= item < self.stop
@ -57,7 +84,11 @@ class MetadataEnum:
the primary values.
"""
def __init__(self, key, standard_values, aliases_map):
def __init__(self,
key: MetaKey,
standard_values: Iterable[MetaValueEnum],
aliases_map: Mapping[MetaValueEnum, MetaValueEnum],
) -> None:
"""Specify allowed values and aliases for this metadata.
Arguments:
@ -76,25 +107,28 @@ class MetadataEnum:
self._aliases.update((v, v) for v in standard_values)
assert self._stdvalues == set(self._aliases.values())
def __repr__(self):
def __repr__(self) -> str:
return "{}<{}>".format(type(self).__name__, self.key)
def __contains__(self, key):
def __contains__(self, key: MetaValueEnum) -> bool:
"""Returns true if `key` is a standard value or alias."""
return key in self._aliases
def __getitem__(self, key):
def __getitem__(self, key: MetaValueEnum) -> MetaValueEnum:
"""Return the standard value for `key`.
Raises KeyError if `key` is not a known value or alias.
"""
return self._aliases[key]
def __iter__(self):
def __iter__(self) -> Iterator[MetaValueEnum]:
"""Iterate over standard values."""
return iter(self._stdvalues)
def get(self, key, default_key=None):
def get(self,
key: MetaValueEnum,
default_key: Optional[MetaValueEnum]=None,
) -> Optional[MetaValueEnum]:
"""Return self[key], or a default fallback if that doesn't exist.
default_key is another key to look up, *not* a default value to return.
@ -121,18 +155,30 @@ class PostingChecker:
# Subclasses may wish to override _default_value and _should_check.
# See below.
HOOK_GROUPS = frozenset(['Posting', 'metadata'])
ACCOUNTS = ('',)
TXN_DATE_RANGE = _GenericRange(DEFAULT_START_DATE, DEFAULT_STOP_DATE)
VALUES_ENUM = {}
METADATA_KEY: ClassVar[MetaKey]
VALUES_ENUM: MetadataEnum
HOOK_GROUPS: AbstractSet[HookName] = frozenset(['Posting', 'metadata'])
ACCOUNTS: Union[str, Tuple[Account, ...]] = ('',)
TXN_DATE_RANGE: _GenericRange = _GenericRange(DEFAULT_START_DATE, DEFAULT_STOP_DATE)
def _meta_get(self, txn, post, key, default=None):
try:
def _meta_get(self,
txn: Transaction,
post: Posting,
key: MetaKey,
default: MetaValue=None,
) -> MetaValue:
if post.meta and key in post.meta:
return post.meta[key]
except (KeyError, TypeError):
else:
return txn.meta.get(key, default)
def _meta_set(self, txn, post, post_index, key, value):
def _meta_set(self,
txn: Transaction,
post: Posting,
post_index: int,
key: MetaKey,
value: MetaValue,
) -> None:
if post.meta is None:
txn.postings[post_index] = Posting(*post[:5], {key: value})
else:
@ -142,22 +188,22 @@ class PostingChecker:
# _default_value to get a default. This method should either return
# a value string from METADATA_ENUM, or else raise InvalidMetadataError.
# This base implementation does the latter.
def _default_value(self, txn, post):
def _default_value(self, txn: Transaction, post: Posting) -> MetaValueEnum:
raise errormod.InvalidMetadataError(txn, post, self.METADATA_KEY)
# The hook calls _should_check on every posting and only checks postings
# when the method returns true. This base method checks the transaction
# date is in TXN_DATE_RANGE, and the posting account name matches ACCOUNTS.
def _should_check(self, txn, post):
def _should_check(self, txn: Transaction, post: Posting) -> bool:
ok = txn.date in self.TXN_DATE_RANGE
if isinstance(self.ACCOUNTS, tuple):
ok = ok and post.account.startswith(self.ACCOUNTS)
else:
ok = ok and re.search(self.ACCOUNTS, post.account)
ok = ok and bool(re.search(self.ACCOUNTS, post.account))
return ok
def run(self, txn, post):
errors = []
def run(self, txn: Transaction, post: Posting, post_index: int) -> Iterable[errormod._BaseError]:
errors: List[errormod._BaseError] = []
if not self._should_check(txn, post):
return errors
source_value = self._meta_get(txn, post, self.METADATA_KEY)
@ -175,5 +221,5 @@ class PostingChecker:
txn, post, self.METADATA_KEY, source_value,
))
if not errors:
self._meta_set(post, self.METADATA_KEY, set_value)
self._meta_set(txn, post, post_index, self.METADATA_KEY, set_value)
return errors