plugin.core: Add type hints.
This commit is contained in:
parent
547ae65780
commit
ee038d7b7d
2 changed files with 130 additions and 26 deletions
58
conservancy_beancount/_typing.py
Normal file
58
conservancy_beancount/_typing.py
Normal file
|
@ -0,0 +1,58 @@
|
|||
"""Type definitions for Conservancy Beancount code"""
|
||||
# Copyright © 2020 Brett Smith
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Affero General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
import abc
|
||||
import datetime
|
||||
|
||||
import beancount.core.data as bc_data
|
||||
|
||||
from typing import (
|
||||
Any,
|
||||
List,
|
||||
NamedTuple,
|
||||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
Type,
|
||||
)
|
||||
|
||||
Account = bc_data.Account
|
||||
HookName = str
|
||||
MetaKey = str
|
||||
MetaValue = Any
|
||||
MetaValueEnum = str
|
||||
Posting = bc_data.Posting
|
||||
|
||||
class LessComparable(metaclass=abc.ABCMeta):
|
||||
@abc.abstractmethod
|
||||
def __le__(self, other: Any) -> bool: ...
|
||||
|
||||
@abc.abstractmethod
|
||||
def __lt__(self, other: Any) -> bool: ...
|
||||
|
||||
|
||||
class Directive(NamedTuple):
|
||||
meta: bc_data.Meta
|
||||
date: datetime.date
|
||||
|
||||
|
||||
class Transaction(Directive):
|
||||
flag: bc_data.Flag
|
||||
payee: Optional[str]
|
||||
narration: str
|
||||
tags: Set
|
||||
links: Set
|
||||
postings: List[Posting]
|
|
@ -19,14 +19,41 @@ import re
|
|||
|
||||
from . import errors as errormod
|
||||
|
||||
from typing import (
|
||||
AbstractSet,
|
||||
Any,
|
||||
ClassVar,
|
||||
Generic,
|
||||
Iterable,
|
||||
Iterator,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
from .._typing import (
|
||||
Account,
|
||||
HookName,
|
||||
LessComparable,
|
||||
MetaKey,
|
||||
MetaValue,
|
||||
MetaValueEnum,
|
||||
Posting,
|
||||
Transaction,
|
||||
)
|
||||
|
||||
# I expect these will become configurable in the future, which is why I'm
|
||||
# keeping them outside of a class, but for now constants will do.
|
||||
DEFAULT_START_DATE = datetime.date(2020, 3, 1)
|
||||
DEFAULT_START_DATE: datetime.date = datetime.date(2020, 3, 1)
|
||||
# The default stop date leaves a little room after so it's easy to test
|
||||
# dates past the far end of the range.
|
||||
DEFAULT_STOP_DATE = datetime.date(datetime.MAXYEAR, 1, 1)
|
||||
DEFAULT_STOP_DATE: datetime.date = datetime.date(datetime.MAXYEAR, 1, 1)
|
||||
|
||||
class _GenericRange:
|
||||
CT = TypeVar('CT', bound=LessComparable)
|
||||
|
||||
class _GenericRange(Generic[CT]):
|
||||
"""Convenience class to check whether a value is within a range.
|
||||
|
||||
`foo in generic_range` is equivalent to `start <= foo < stop`.
|
||||
|
@ -35,17 +62,17 @@ class _GenericRange:
|
|||
makes it easier for subclasses to override.
|
||||
"""
|
||||
|
||||
def __init__(self, start, stop):
|
||||
def __init__(self, start: CT, stop: CT) -> None:
|
||||
self.start = start
|
||||
self.stop = stop
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return "{clsname}({self.start!r}, {self.stop!r})".format(
|
||||
clsname=type(self).__name__,
|
||||
self=self,
|
||||
)
|
||||
|
||||
def __contains__(self, item):
|
||||
def __contains__(self, item: CT) -> bool:
|
||||
return self.start <= item < self.stop
|
||||
|
||||
|
||||
|
@ -57,7 +84,11 @@ class MetadataEnum:
|
|||
the primary values.
|
||||
"""
|
||||
|
||||
def __init__(self, key, standard_values, aliases_map):
|
||||
def __init__(self,
|
||||
key: MetaKey,
|
||||
standard_values: Iterable[MetaValueEnum],
|
||||
aliases_map: Mapping[MetaValueEnum, MetaValueEnum],
|
||||
) -> None:
|
||||
"""Specify allowed values and aliases for this metadata.
|
||||
|
||||
Arguments:
|
||||
|
@ -76,25 +107,28 @@ class MetadataEnum:
|
|||
self._aliases.update((v, v) for v in standard_values)
|
||||
assert self._stdvalues == set(self._aliases.values())
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return "{}<{}>".format(type(self).__name__, self.key)
|
||||
|
||||
def __contains__(self, key):
|
||||
def __contains__(self, key: MetaValueEnum) -> bool:
|
||||
"""Returns true if `key` is a standard value or alias."""
|
||||
return key in self._aliases
|
||||
|
||||
def __getitem__(self, key):
|
||||
def __getitem__(self, key: MetaValueEnum) -> MetaValueEnum:
|
||||
"""Return the standard value for `key`.
|
||||
|
||||
Raises KeyError if `key` is not a known value or alias.
|
||||
"""
|
||||
return self._aliases[key]
|
||||
|
||||
def __iter__(self):
|
||||
def __iter__(self) -> Iterator[MetaValueEnum]:
|
||||
"""Iterate over standard values."""
|
||||
return iter(self._stdvalues)
|
||||
|
||||
def get(self, key, default_key=None):
|
||||
def get(self,
|
||||
key: MetaValueEnum,
|
||||
default_key: Optional[MetaValueEnum]=None,
|
||||
) -> Optional[MetaValueEnum]:
|
||||
"""Return self[key], or a default fallback if that doesn't exist.
|
||||
|
||||
default_key is another key to look up, *not* a default value to return.
|
||||
|
@ -121,18 +155,30 @@ class PostingChecker:
|
|||
# Subclasses may wish to override _default_value and _should_check.
|
||||
# See below.
|
||||
|
||||
HOOK_GROUPS = frozenset(['Posting', 'metadata'])
|
||||
ACCOUNTS = ('',)
|
||||
TXN_DATE_RANGE = _GenericRange(DEFAULT_START_DATE, DEFAULT_STOP_DATE)
|
||||
VALUES_ENUM = {}
|
||||
METADATA_KEY: ClassVar[MetaKey]
|
||||
VALUES_ENUM: MetadataEnum
|
||||
HOOK_GROUPS: AbstractSet[HookName] = frozenset(['Posting', 'metadata'])
|
||||
ACCOUNTS: Union[str, Tuple[Account, ...]] = ('',)
|
||||
TXN_DATE_RANGE: _GenericRange = _GenericRange(DEFAULT_START_DATE, DEFAULT_STOP_DATE)
|
||||
|
||||
def _meta_get(self, txn, post, key, default=None):
|
||||
try:
|
||||
def _meta_get(self,
|
||||
txn: Transaction,
|
||||
post: Posting,
|
||||
key: MetaKey,
|
||||
default: MetaValue=None,
|
||||
) -> MetaValue:
|
||||
if post.meta and key in post.meta:
|
||||
return post.meta[key]
|
||||
except (KeyError, TypeError):
|
||||
else:
|
||||
return txn.meta.get(key, default)
|
||||
|
||||
def _meta_set(self, txn, post, post_index, key, value):
|
||||
def _meta_set(self,
|
||||
txn: Transaction,
|
||||
post: Posting,
|
||||
post_index: int,
|
||||
key: MetaKey,
|
||||
value: MetaValue,
|
||||
) -> None:
|
||||
if post.meta is None:
|
||||
txn.postings[post_index] = Posting(*post[:5], {key: value})
|
||||
else:
|
||||
|
@ -142,22 +188,22 @@ class PostingChecker:
|
|||
# _default_value to get a default. This method should either return
|
||||
# a value string from METADATA_ENUM, or else raise InvalidMetadataError.
|
||||
# This base implementation does the latter.
|
||||
def _default_value(self, txn, post):
|
||||
def _default_value(self, txn: Transaction, post: Posting) -> MetaValueEnum:
|
||||
raise errormod.InvalidMetadataError(txn, post, self.METADATA_KEY)
|
||||
|
||||
# The hook calls _should_check on every posting and only checks postings
|
||||
# when the method returns true. This base method checks the transaction
|
||||
# date is in TXN_DATE_RANGE, and the posting account name matches ACCOUNTS.
|
||||
def _should_check(self, txn, post):
|
||||
def _should_check(self, txn: Transaction, post: Posting) -> bool:
|
||||
ok = txn.date in self.TXN_DATE_RANGE
|
||||
if isinstance(self.ACCOUNTS, tuple):
|
||||
ok = ok and post.account.startswith(self.ACCOUNTS)
|
||||
else:
|
||||
ok = ok and re.search(self.ACCOUNTS, post.account)
|
||||
ok = ok and bool(re.search(self.ACCOUNTS, post.account))
|
||||
return ok
|
||||
|
||||
def run(self, txn, post):
|
||||
errors = []
|
||||
def run(self, txn: Transaction, post: Posting, post_index: int) -> Iterable[errormod._BaseError]:
|
||||
errors: List[errormod._BaseError] = []
|
||||
if not self._should_check(txn, post):
|
||||
return errors
|
||||
source_value = self._meta_get(txn, post, self.METADATA_KEY)
|
||||
|
@ -175,5 +221,5 @@ class PostingChecker:
|
|||
txn, post, self.METADATA_KEY, source_value,
|
||||
))
|
||||
if not errors:
|
||||
self._meta_set(post, self.METADATA_KEY, set_value)
|
||||
self._meta_set(txn, post, post_index, self.METADATA_KEY, set_value)
|
||||
return errors
|
||||
|
|
Loading…
Reference in a new issue