plugin: Introduce HookRegistry.

This is the layer that keeps track of the different groups of hooks and
can filter them before runtime. The idea here is that you'll be able
to do things like skip hooks that require network access when you don't
have it, or skip CPU-intensive hooks when you don't need them, etc.
This commit is contained in:
Brett Smith 2020-03-06 09:22:55 -05:00
parent d145e22734
commit d34db71542
4 changed files with 113 additions and 14 deletions

View file

@ -14,6 +14,8 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
import importlib
import beancount.core.data as bc_data
__plugins__ = ['run']
@ -24,18 +26,61 @@ class HookRegistry:
'Posting',
])
@classmethod
def group_by_directive(cls, hooks_seq):
hooks_map = {key: [] for key in cls.DIRECTIVES}
for hook in hooks_seq:
for key in cls.DIRECTIVES & hook.HOOK_GROUPS:
hooks_map[key].append(hook)
return hooks_map
def __init__(self):
self.group_hooks_map = {key: set() for key in self.DIRECTIVES}
def add_hook(self, hook_cls):
hook_groups = list(hook_cls.HOOK_GROUPS)
assert self.DIRECTIVES.intersection(hook_groups)
hook_groups.append('all')
for name_attr in ['HOOK_NAME', 'METADATA_KEY', '__name__']:
try:
hook_name = getattr(hook_cls, name_attr)
except AttributeError:
pass
else:
hook_groups.append(hook_name)
break
for key in hook_groups:
self.group_hooks_map.setdefault(key, set()).add(hook_cls)
return hook_cls # to allow use as a decorator
def import_hooks(self, mod_name, *hook_names, package=__module__):
module = importlib.import_module(mod_name, package)
for hook_name in hook_names:
self.add_hook(getattr(module, hook_name))
def group_by_directive(self, config_str=''):
config_str = config_str.strip()
if not config_str:
config_str = 'all'
elif config_str.startswith('-'):
config_str = 'all ' + config_str
available_hooks = set()
for token in config_str.split():
if token.startswith('-'):
update_available = available_hooks.difference_update
key = token[1:]
else:
update_available = available_hooks.update
key = token
try:
update_set = self.group_hooks_map[key]
except KeyError:
raise ValueError("configuration refers to unknown hooks {!r}".format(key)) from None
else:
update_available(update_set)
return {key: [hook() for hook in self.group_hooks_map[key] & available_hooks]
for key in self.DIRECTIVES}
def run(entries, options_map, config):
HOOK_REGISTRY = HookRegistry()
HOOK_REGISTRY.import_hooks('.meta_expense_allocation', 'MetaExpenseAllocation')
HOOK_REGISTRY.import_hooks('.meta_tax_implication', 'MetaTaxImplication')
def run(entries, options_map, config='', hook_registry=HOOK_REGISTRY):
errors = []
hooks = HookRegistry.group_by_directive(config)
hooks = hook_registry.group_by_directive(config)
for entry in entries:
entry_type = type(entry).__name__
for hook in hooks[entry_type]:

View file

@ -67,6 +67,7 @@ class MetadataEnum:
class PostingChecker:
HOOK_GROUPS = frozenset(['Posting', 'metadata'])
ACCOUNTS = ('',)
TXN_DATE_RANGE = _GenericRange(DEFAULT_START_DATE, DEFAULT_STOP_DATE)
VALUES_ENUM = {}

View file

@ -0,0 +1,53 @@
"""Test main plugin's HookRegistry"""
# Copyright © 2020 Brett Smith
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
import pytest
from . import testutil
from conservancy_beancount import plugin
def hook_names(hooks, key):
return {type(hook).__name__ for hook in hooks[key]}
def test_default_registrations():
hooks = plugin.HOOK_REGISTRY.group_by_directive()
post_hook_names = hook_names(hooks, 'Posting')
assert len(post_hook_names) >= 2
assert 'MetaExpenseAllocation' in post_hook_names
assert 'MetaTaxImplication' in post_hook_names
def test_exclude_single():
hooks = plugin.HOOK_REGISTRY.group_by_directive('-expenseAllocation')
post_hook_names = hook_names(hooks, 'Posting')
assert post_hook_names
assert 'MetaExpenseAllocation' not in post_hook_names
def test_exclude_group_then_include_single():
hooks = plugin.HOOK_REGISTRY.group_by_directive('-metadata expenseAllocation')
post_hook_names = hook_names(hooks, 'Posting')
assert 'MetaExpenseAllocation' in post_hook_names
assert 'MetaTaxImplication' not in post_hook_names
def test_include_group_then_exclude_single():
hooks = plugin.HOOK_REGISTRY.group_by_directive('metadata -taxImplication')
post_hook_names = hook_names(hooks, 'Posting')
assert 'MetaExpenseAllocation' in post_hook_names
assert 'MetaTaxImplication' not in post_hook_names
def test_unknown_group_name():
with pytest.raises(ValueError):
plugin.HOOK_REGISTRY.group_by_directive('UnKnownTestGroup')

View file

@ -21,7 +21,9 @@ from . import testutil
from conservancy_beancount import plugin
CONFIG_MAP = {}
HOOK_REGISTRY = plugin.HookRegistry()
@HOOK_REGISTRY.add_hook
class TransactionCounter:
HOOK_GROUPS = frozenset(['Transaction', 'counter'])
@ -29,6 +31,7 @@ class TransactionCounter:
return ['txn:{}'.format(id(txn))]
@HOOK_REGISTRY.add_hook
class PostingCounter(TransactionCounter):
HOOK_GROUPS = frozenset(['Posting', 'counter'])
@ -44,8 +47,6 @@ def map_errors(errors):
return retval
def test_with_multiple_hooks():
txn_counter = TransactionCounter()
post_counter = PostingCounter()
in_entries = [
testutil.Transaction(postings=[
('Income:Donations', -25),
@ -56,14 +57,13 @@ def test_with_multiple_hooks():
('Liabilites:CreditCard', -10),
]),
]
out_entries, errors = plugin.run(in_entries, CONFIG_MAP, [txn_counter, post_counter])
out_entries, errors = plugin.run(in_entries, CONFIG_MAP, '', HOOK_REGISTRY)
assert len(out_entries) == 2
errmap = map_errors(errors)
assert len(errmap.get('txn', '')) == 2
assert len(errmap.get('post', '')) == 4
def test_with_posting_hooks_only():
post_counter = PostingCounter()
in_entries = [
testutil.Transaction(postings=[
('Income:Donations', -25),
@ -74,7 +74,7 @@ def test_with_posting_hooks_only():
('Liabilites:CreditCard', -10),
]),
]
out_entries, errors = plugin.run(in_entries, CONFIG_MAP, [post_counter])
out_entries, errors = plugin.run(in_entries, CONFIG_MAP, 'Posting', HOOK_REGISTRY)
assert len(out_entries) == 2
errmap = map_errors(errors)
assert len(errmap.get('txn', '')) == 0