plugin: Transform posting hooks into transaction hooks.

I feel like posting hooks were a case of premature optimization in early
development. This approach reduces the number of special cases in
the code and allows us to reason more strongly about hooks in the
type system.
This commit is contained in:
Brett Smith 2020-03-15 15:50:14 -04:00
parent c9ff4ab746
commit a41feb94b3
9 changed files with 180 additions and 122 deletions

View file

@ -18,9 +18,12 @@ import abc
import datetime
import beancount.core.data as bc_data
from .plugin import errors
from typing import (
Any,
FrozenSet,
Iterable,
List,
NamedTuple,
Optional,
@ -30,7 +33,8 @@ from typing import (
)
Account = bc_data.Account
HookName = str
Error = errors._BaseError
ErrorIter = Iterable[Error]
MetaKey = str
MetaValue = Any
MetaValueEnum = str
@ -56,3 +60,8 @@ class Transaction(Directive):
tags: Set
links: Set
postings: List[Posting]
# Directive classes collected into one immutable set. Transaction is
# currently the only member.
ALL_DIRECTIVES: FrozenSet[Type[Directive]] = frozenset([
Transaction,
])

View file

@ -18,31 +18,40 @@ import importlib
import beancount.core.data as bc_data
from typing import (
AbstractSet,
Any,
Dict,
List,
Mapping,
Set,
Tuple,
Type,
)
from .._typing import (
ALL_DIRECTIVES,
Directive,
Error,
)
from .core import (
Hook,
HookName,
)
__plugins__ = ['run']
class HookRegistry:
DIRECTIVES = frozenset([
*(cls.__name__ for cls in bc_data.ALL_DIRECTIVES),
'Posting',
])
def __init__(self) -> None:
    """Start with an empty hook set for each directive name, plus 'all'."""
    # Directive names come first and 'all' last, so the mapping is built
    # in the same insertion order as before.
    group_names = [directive.__name__ for directive in ALL_DIRECTIVES]
    group_names.append('all')
    self.group_name_map: Dict[HookName, Set[Type[Hook]]] = {
        name: set() for name in group_names
    }
def __init__(self):
self.group_hooks_map = {key: set() for key in self.DIRECTIVES}
def add_hook(self, hook_cls):
hook_groups = list(hook_cls.HOOK_GROUPS)
assert self.DIRECTIVES.intersection(hook_groups)
hook_groups.append('all')
for name_attr in ['HOOK_NAME', 'METADATA_KEY', '__name__']:
try:
hook_name = getattr(hook_cls, name_attr)
except AttributeError:
pass
else:
hook_groups.append(hook_name)
break
for key in hook_groups:
self.group_hooks_map.setdefault(key, set()).add(hook_cls)
def add_hook(self, hook_cls: Type[Hook]) -> Type[Hook]:
    """Register ``hook_cls`` and return it unchanged.

    The class is filed under 'all', under the name of its DIRECTIVE class,
    and under every name in its HOOK_GROUPS. Returning the class lets this
    method be used as a class decorator.

    Raises KeyError if the DIRECTIVE name has no group in this registry.
    """
    groups = self.group_name_map
    groups['all'].add(hook_cls)
    # Plain indexing on purpose: an unknown directive name must fail loudly
    # rather than silently create a new group.
    directive_name = hook_cls.DIRECTIVE.__name__
    groups[directive_name].add(hook_cls)
    # Custom group names are created on first use.
    for group_name in hook_cls.HOOK_GROUPS:
        groups.setdefault(group_name, set()).add(hook_cls)
    return hook_cls
def import_hooks(self, mod_name, *hook_names, package=__module__):
@ -50,13 +59,13 @@ class HookRegistry:
for hook_name in hook_names:
self.add_hook(getattr(module, hook_name))
def group_by_directive(self, config_str=''):
def group_by_directive(self, config_str: str='') -> Mapping[HookName, List[Hook]]:
config_str = config_str.strip()
if not config_str:
config_str = 'all'
elif config_str.startswith('-'):
config_str = 'all ' + config_str
available_hooks = set()
available_hooks: Set[Type[Hook]] = set()
for token in config_str.split():
if token.startswith('-'):
update_available = available_hooks.difference_update
@ -65,29 +74,32 @@ class HookRegistry:
update_available = available_hooks.update
key = token
try:
update_set = self.group_hooks_map[key]
update_set = self.group_name_map[key]
except KeyError:
raise ValueError("configuration refers to unknown hooks {!r}".format(key)) from None
else:
update_available(update_set)
return {key: [hook() for hook in self.group_hooks_map[key] & available_hooks]
for key in self.DIRECTIVES}
return {
t.__name__: [hook() for hook in self.group_name_map[t.__name__] & available_hooks]
for t in ALL_DIRECTIVES
}
HOOK_REGISTRY = HookRegistry()
HOOK_REGISTRY.import_hooks('.meta_expense_allocation', 'MetaExpenseAllocation')
HOOK_REGISTRY.import_hooks('.meta_tax_implication', 'MetaTaxImplication')
def run(entries, options_map, config='', hook_registry=HOOK_REGISTRY):
errors = []
def run(
entries: List[Directive],
options_map: Dict[str, Any],
config: str='',
hook_registry: HookRegistry=HOOK_REGISTRY,
) -> Tuple[List[Directive], List[Error]]:
errors: List[Error] = []
hooks = hook_registry.group_by_directive(config)
for entry in entries:
entry_type = type(entry).__name__
for hook in hooks[entry_type]:
errors.extend(hook.run(entry))
if entry_type == 'Transaction':
for index, post in enumerate(entry.postings):
for hook in hooks['Posting']:
errors.extend(hook.run(entry, post, index))
return entries, errors

View file

@ -14,36 +14,37 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
import abc
import datetime
import re
from . import errors as errormod
from typing import (
AbstractSet,
Any,
ClassVar,
FrozenSet,
Generic,
Iterable,
Iterator,
List,
Mapping,
Optional,
Tuple,
TypeVar,
Union,
)
from .._typing import (
Account,
HookName,
Directive,
Error,
ErrorIter,
LessComparable,
MetaKey,
MetaValue,
MetaValueEnum,
Posting,
Transaction,
Type,
)
### CONSTANTS
# I expect these will become configurable in the future, which is why I'm
# keeping them outside of a class, but for now constants will do.
DEFAULT_START_DATE: datetime.date = datetime.date(2020, 3, 1)
@ -51,8 +52,27 @@ DEFAULT_START_DATE: datetime.date = datetime.date(2020, 3, 1)
# dates past the far end of the range.
DEFAULT_STOP_DATE: datetime.date = datetime.date(datetime.MAXYEAR, 1, 1)
CT = TypeVar('CT', bound=LessComparable)
### TYPE DEFINITIONS
HookName = str
Entry = TypeVar('Entry', bound=Directive)
class Hook(Generic[Entry], metaclass=abc.ABCMeta):
    """Abstract base class for plugin hooks that check one kind of directive.

    Concrete hooks subclass a parameterized form such as
    ``Hook[Transaction]``. The type argument is recorded on the subclass as
    ``DIRECTIVE`` when the subclass is created.
    """
    # The directive class this hook handles; set by __init_subclass__.
    DIRECTIVE: Type[Directive]
    # Names of additional groups this hook belongs to (none by default).
    HOOK_GROUPS: FrozenSet[HookName] = frozenset()

    @abc.abstractmethod
    def run(self, entry: Entry) -> ErrorIter:
        """Check ``entry`` and return an iterable of the errors found."""
        ...

    def __init_subclass__(cls, **kwargs: Any) -> None:
        """Record the directive class a new subclass was parameterized with.

        Scans the subclass's original bases for a parameterized Hook (for
        example ``Hook[Transaction]``) whose first argument is a concrete
        class, and stores that class as ``DIRECTIVE``. When no such base is
        found, any inherited ``DIRECTIVE`` is left untouched.
        """
        # Cooperate with the rest of the MRO. typing.Generic defines its own
        # __init_subclass__, which must not be skipped.
        super().__init_subclass__(**kwargs)
        # Search every base instead of assuming the parameterized Hook is
        # listed first, so a leading mixin does not break subclass creation.
        for base in getattr(cls, '__orig_bases__', ()):
            origin = getattr(base, '__origin__', None)
            args = getattr(base, '__args__', None) or ()
            if not (isinstance(origin, type) and issubclass(origin, Hook)):
                continue
            # Ignore unresolved type variables such as Hook[Entry]; only a
            # real class is a usable directive type.
            if args and isinstance(args[0], type):
                cls.DIRECTIVE = args[0]
                break
# Convenience alias: the Hook base parameterized for Transaction directives.
TransactionHook = Hook[Transaction]
### HELPER CLASSES
CT = TypeVar('CT', bound=LessComparable)
class _GenericRange(Generic[CT]):
"""Convenience class to check whether a value is within a range.
@ -143,24 +163,14 @@ class MetadataEnum:
return self[default_key]
class PostingChecker:
"""Base class to normalize posting metadata from an enum."""
# This class provides basic functionality to filter postings, normalize
# metadata values, and set default values.
# Subclasses should set:
# * METADATA_KEY: A string with the name of the metadata key to normalize.
# * ACCOUNTS: Only check postings that match these account names.
# Can be a tuple of account prefix strings, or a regexp.
# * VALUES_ENUM: A MetadataEnum with allowed values and aliases.
# Subclasses may wish to override _default_value and _should_check.
# See below.
### HOOK SUBCLASSES
METADATA_KEY: ClassVar[MetaKey]
VALUES_ENUM: MetadataEnum
HOOK_GROUPS: AbstractSet[HookName] = frozenset(['Posting', 'metadata'])
ACCOUNTS: Union[str, Tuple[Account, ...]] = ('',)
class _PostingHook(TransactionHook, metaclass=abc.ABCMeta):
TXN_DATE_RANGE: _GenericRange = _GenericRange(DEFAULT_START_DATE, DEFAULT_STOP_DATE)
def __init_subclass__(cls) -> None:
# Add the 'posting' group to whatever HOOK_GROUPS the new subclass already
# has. union() builds a new set, so the parent's set is not modified.
# NOTE(review): unlike _NormalizePostingMetadataHook.__init_subclass__, this
# override does not call super().__init_subclass__() - confirm that
# skipping the parent hooks is intended.
cls.HOOK_GROUPS = cls.HOOK_GROUPS.union(['posting'])
def _meta_get(self,
txn: Transaction,
post: Posting,
@ -184,6 +194,34 @@ class PostingChecker:
else:
post.meta[key] = value
def _run_on_txn(self, txn: Transaction) -> bool:
# True when the transaction's date falls inside TXN_DATE_RANGE.
return txn.date in self.TXN_DATE_RANGE
def _run_on_post(self, txn: Transaction, post: Posting) -> bool:
# Per-posting filter consulted by run(). The base implementation accepts
# every posting.
return True
def run(self, txn: Transaction) -> ErrorIter:
    """Yield the errors post_run() reports for each accepted posting.

    Nothing is yielded when _run_on_txn() rejects the transaction, and
    postings rejected by _run_on_post() are skipped.
    """
    if not self._run_on_txn(txn):
        return
    for post_index, posting in enumerate(txn.postings):
        if not self._run_on_post(txn, posting):
            continue
        yield from self.post_run(txn, posting, post_index)
# Per-posting check implemented by subclasses. run() calls it with the
# transaction, the posting, and the posting's index within txn.postings.
@abc.abstractmethod
def post_run(self, txn: Transaction, post: Posting, post_index: int) -> ErrorIter: ...
class _NormalizePostingMetadataHook(_PostingHook):
"""Base class to normalize posting metadata from an enum."""
# This class provides basic functionality to filter postings, normalize
# metadata values, and set default values.
METADATA_KEY: MetaKey
VALUES_ENUM: MetadataEnum
def __init_subclass__(cls) -> None:
    """Derive METADATA_KEY and the hook groups from the subclass's VALUES_ENUM."""
    super().__init_subclass__()
    # The enum already knows its metadata key, so subclasses only declare
    # VALUES_ENUM and the key is copied from it here.
    metadata_key = cls.VALUES_ENUM.key
    cls.METADATA_KEY = metadata_key
    # Selectable both through the generic 'metadata' group and by its own key.
    cls.HOOK_GROUPS = cls.HOOK_GROUPS.union(['metadata', metadata_key])
# If the posting does not specify METADATA_KEY, the hook calls
# _default_value to get a default. This method should either return
# a value string from VALUES_ENUM, or else raise InvalidMetadataError.
@ -191,35 +229,23 @@ class PostingChecker:
def _default_value(self, txn: Transaction, post: Posting) -> MetaValueEnum:
# Base behavior: there is no default, so this always raises
# InvalidMetadataError for METADATA_KEY on this posting.
raise errormod.InvalidMetadataError(txn, post, self.METADATA_KEY)
# The hook calls _should_check on every posting and only checks postings
# when the method returns true. This base method checks the transaction
# date is in TXN_DATE_RANGE, and the posting account name matches ACCOUNTS.
def _should_check(self, txn: Transaction, post: Posting) -> bool:
ok = txn.date in self.TXN_DATE_RANGE
if isinstance(self.ACCOUNTS, tuple):
ok = ok and post.account.startswith(self.ACCOUNTS)
else:
ok = ok and bool(re.search(self.ACCOUNTS, post.account))
return ok
def run(self, txn: Transaction, post: Posting, post_index: int) -> Iterable[errormod._BaseError]:
errors: List[errormod._BaseError] = []
if not self._should_check(txn, post):
return errors
def post_run(self, txn: Transaction, post: Posting, post_index: int) -> ErrorIter:
source_value = self._meta_get(txn, post, self.METADATA_KEY)
set_value = source_value
error: Optional[Error] = None
if source_value is None:
try:
set_value = self._default_value(txn, post)
except errormod._BaseError as error:
errors.append(error)
except errormod._BaseError as error_:
error = error_
else:
try:
set_value = self.VALUES_ENUM[source_value]
except KeyError:
errors.append(errormod.InvalidMetadataError(
error = errormod.InvalidMetadataError(
txn, post, self.METADATA_KEY, source_value,
))
if not errors:
)
if error is None:
self._meta_set(txn, post, post_index, self.METADATA_KEY, set_value)
return errors
else:
yield error

View file

@ -15,11 +15,14 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from . import core
from .._typing import (
MetaValueEnum,
Posting,
Transaction,
)
class MetaExpenseAllocation(core.PostingChecker):
ACCOUNTS = ('Expenses:',)
METADATA_KEY = 'expense-allocation'
VALUES_ENUM = core.MetadataEnum(METADATA_KEY, {
class MetaExpenseAllocation(core._NormalizePostingMetadataHook):
VALUES_ENUM = core.MetadataEnum('expense-allocation', {
'administration',
'fundraising',
'program',
@ -32,5 +35,8 @@ class MetaExpenseAllocation(core.PostingChecker):
'Expenses:Services:Fundraising': VALUES_ENUM['fundraising'],
}
def _default_value(self, txn, post):
def _run_on_post(self, txn: Transaction, post: Posting) -> bool:
# Only postings whose account name starts with 'Expenses:' are checked.
return post.account.startswith('Expenses:')
def _default_value(self, txn: Transaction, post: Posting) -> MetaValueEnum:
# Look the account up in DEFAULT_VALUES; accounts not listed there
# default to 'program'.
return self.DEFAULT_VALUES.get(post.account, 'program')

View file

@ -17,13 +17,15 @@
import decimal
from . import core
from .._typing import (
Posting,
Transaction,
)
DEFAULT_STOP_AMOUNT = decimal.Decimal(0)
class MetaTaxImplication(core.PostingChecker):
ACCOUNTS = ('Assets:',)
METADATA_KEY = 'tax-implication'
VALUES_ENUM = core.MetadataEnum(METADATA_KEY, [
class MetaTaxImplication(core._NormalizePostingMetadataHook):
VALUES_ENUM = core.MetadataEnum('tax-implication', [
'1099',
'Accountant-Advises-No-1099',
'Bank-Transfer',
@ -43,8 +45,9 @@ class MetaTaxImplication(core.PostingChecker):
'W2',
], {})
def _should_check(self, txn, post):
return (
super()._should_check(txn, post)
def _run_on_post(self, txn: Transaction, post: Posting) -> bool:
    """Return True only for 'Assets:' postings with a nonzero number below DEFAULT_STOP_AMOUNT."""
    # Test the account first so post.units is never touched for postings
    # outside 'Assets:'.
    if not post.account.startswith('Assets:'):
        return False
    number = post.units.number
    # A missing or zero number is falsy and rules the posting out.
    return bool(number and number < DEFAULT_STOP_AMOUNT)

View file

@ -44,7 +44,7 @@ def test_valid_values_on_postings(src_value, set_value):
('Expenses:General', 25, {TEST_KEY: src_value}),
])
checker = meta_expense_allocation.MetaExpenseAllocation()
errors = checker.run(txn, txn.postings[-1], -1)
errors = list(checker.run(txn))
assert not errors
assert txn.postings[-1].meta.get(TEST_KEY) == set_value
@ -55,7 +55,7 @@ def test_invalid_values_on_postings(src_value):
('Expenses:General', 25, {TEST_KEY: src_value}),
])
checker = meta_expense_allocation.MetaExpenseAllocation()
errors = checker.run(txn, txn.postings[-1], -1)
errors = list(checker.run(txn))
assert errors
@pytest.mark.parametrize('src_value,set_value', VALID_VALUES.items())
@ -65,7 +65,7 @@ def test_valid_values_on_transactions(src_value, set_value):
('Expenses:General', 25),
])
checker = meta_expense_allocation.MetaExpenseAllocation()
errors = checker.run(txn, txn.postings[-1], -1)
errors = list(checker.run(txn))
assert not errors
assert txn.postings[-1].meta.get(TEST_KEY) == set_value
@ -76,7 +76,7 @@ def test_invalid_values_on_transactions(src_value):
('Expenses:General', 25),
])
checker = meta_expense_allocation.MetaExpenseAllocation()
errors = checker.run(txn, txn.postings[-1], -1)
errors = list(checker.run(txn))
assert errors
@pytest.mark.parametrize('account', [
@ -92,7 +92,7 @@ def test_non_expense_accounts_skipped(account):
('Expenses:General', 25, {TEST_KEY: 'program'}),
])
checker = meta_expense_allocation.MetaExpenseAllocation()
errors = checker.run(txn, txn.postings[0], 0)
errors = list(checker.run(txn))
assert not errors
@pytest.mark.parametrize('account,set_value', [
@ -108,7 +108,7 @@ def test_default_values(account, set_value):
(account, 25),
])
checker = meta_expense_allocation.MetaExpenseAllocation()
errors = checker.run(txn, txn.postings[-1], -1)
errors = list(checker.run(txn))
assert not errors
assert txn.postings[-1].meta[TEST_KEY] == set_value
@ -125,7 +125,7 @@ def test_default_value_set_in_date_range(date, set_value):
('Expenses:General', 25),
])
checker = meta_expense_allocation.MetaExpenseAllocation()
errors = checker.run(txn, txn.postings[-1], -1)
errors = list(checker.run(txn))
assert not errors
got_value = (txn.postings[-1].meta or {}).get(TEST_KEY)
assert bool(got_value) == bool(set_value)

View file

@ -56,7 +56,7 @@ def test_valid_values_on_postings(src_value, set_value):
('Assets:Cash', -25, {TEST_KEY: src_value}),
])
checker = meta_tax_implication.MetaTaxImplication()
errors = checker.run(txn, txn.postings[-1], -1)
errors = list(checker.run(txn))
assert not errors
assert txn.postings[-1].meta.get(TEST_KEY) == set_value
@ -67,7 +67,7 @@ def test_invalid_values_on_postings(src_value):
('Assets:Cash', -25, {TEST_KEY: src_value}),
])
checker = meta_tax_implication.MetaTaxImplication()
errors = checker.run(txn, txn.postings[-1], -1)
errors = list(checker.run(txn))
assert errors
@pytest.mark.parametrize('src_value,set_value', VALID_VALUES.items())
@ -77,7 +77,7 @@ def test_valid_values_on_transactions(src_value, set_value):
('Assets:Cash', -25),
])
checker = meta_tax_implication.MetaTaxImplication()
errors = checker.run(txn, txn.postings[-1], -1)
errors = list(checker.run(txn))
assert not errors
assert txn.postings[-1].meta.get(TEST_KEY) == set_value
@ -88,7 +88,7 @@ def test_invalid_values_on_transactions(src_value):
('Assets:Cash', -25),
])
checker = meta_tax_implication.MetaTaxImplication()
errors = checker.run(txn, txn.postings[-1], -1)
errors = list(checker.run(txn))
assert errors
@pytest.mark.parametrize('account', [
@ -102,7 +102,7 @@ def test_non_asset_accounts_skipped(account):
('Assets:Cash', -25, {TEST_KEY: 'USA-Corporation'}),
])
checker = meta_tax_implication.MetaTaxImplication()
errors = checker.run(txn, txn.postings[0], 0)
errors = list(checker.run(txn))
assert not errors
def test_asset_credits_skipped():
@ -111,7 +111,7 @@ def test_asset_credits_skipped():
('Assets:Cash', 25),
])
checker = meta_tax_implication.MetaTaxImplication()
errors = checker.run(txn, txn.postings[-1], -1)
errors = list(checker.run(txn))
assert not errors
assert not txn.postings[-1].meta
@ -128,5 +128,5 @@ def test_default_value_set_in_date_range(date, need_value):
('Assets:Cash', -25),
])
checker = meta_tax_implication.MetaTaxImplication()
errors = checker.run(txn, txn.postings[-1], -1)
errors = list(checker.run(txn))
assert bool(errors) == bool(need_value)

View file

@ -25,28 +25,28 @@ def hook_names(hooks, key):
def test_default_registrations():
hooks = plugin.HOOK_REGISTRY.group_by_directive()
post_hook_names = hook_names(hooks, 'Posting')
assert len(post_hook_names) >= 2
assert 'MetaExpenseAllocation' in post_hook_names
assert 'MetaTaxImplication' in post_hook_names
txn_hook_names = hook_names(hooks, 'Transaction')
assert len(txn_hook_names) >= 2
assert 'MetaExpenseAllocation' in txn_hook_names
assert 'MetaTaxImplication' in txn_hook_names
def test_exclude_single():
hooks = plugin.HOOK_REGISTRY.group_by_directive('-expense-allocation')
post_hook_names = hook_names(hooks, 'Posting')
assert post_hook_names
assert 'MetaExpenseAllocation' not in post_hook_names
txn_hook_names = hook_names(hooks, 'Transaction')
assert txn_hook_names
assert 'MetaExpenseAllocation' not in txn_hook_names
def test_exclude_group_then_include_single():
hooks = plugin.HOOK_REGISTRY.group_by_directive('-metadata expense-allocation')
post_hook_names = hook_names(hooks, 'Posting')
assert 'MetaExpenseAllocation' in post_hook_names
assert 'MetaTaxImplication' not in post_hook_names
txn_hook_names = hook_names(hooks, 'Transaction')
assert 'MetaExpenseAllocation' in txn_hook_names
assert 'MetaTaxImplication' not in txn_hook_names
def test_include_group_then_exclude_single():
hooks = plugin.HOOK_REGISTRY.group_by_directive('metadata -tax-implication')
post_hook_names = hook_names(hooks, 'Posting')
assert 'MetaExpenseAllocation' in post_hook_names
assert 'MetaTaxImplication' not in post_hook_names
txn_hook_names = hook_names(hooks, 'Transaction')
assert 'MetaExpenseAllocation' in txn_hook_names
assert 'MetaTaxImplication' not in txn_hook_names
def test_unknown_group_name():
with pytest.raises(ValueError):

View file

@ -18,14 +18,15 @@ import pytest
from . import testutil
from conservancy_beancount import plugin
from conservancy_beancount import plugin, _typing
CONFIG_MAP = {}
HOOK_REGISTRY = plugin.HookRegistry()
@HOOK_REGISTRY.add_hook
class TransactionCounter:
HOOK_GROUPS = frozenset(['Transaction', 'counter'])
DIRECTIVE = _typing.Transaction
HOOK_GROUPS = frozenset()
def run(self, txn):
return ['txn:{}'.format(id(txn))]
@ -33,10 +34,11 @@ class TransactionCounter:
@HOOK_REGISTRY.add_hook
class PostingCounter(TransactionCounter):
HOOK_GROUPS = frozenset(['Posting', 'counter'])
DIRECTIVE = _typing.Transaction
HOOK_GROUPS = frozenset(['posting'])
def run(self, txn, post, post_index):
return ['post:{}'.format(id(post))]
def run(self, txn):
return ['post:{}'.format(id(post)) for post in txn.postings]
def map_errors(errors):
@ -74,7 +76,7 @@ def test_with_posting_hooks_only():
('Liabilites:CreditCard', -10),
]),
]
out_entries, errors = plugin.run(in_entries, CONFIG_MAP, 'Posting', HOOK_REGISTRY)
out_entries, errors = plugin.run(in_entries, CONFIG_MAP, 'posting', HOOK_REGISTRY)
assert len(out_entries) == 2
errmap = map_errors(errors)
assert len(errmap.get('txn', '')) == 0