plugin: Move hook initialization from HookRegistry to run().

Makes more sense here so run can report errors in hook configuration.
This commit is contained in:
Brett Smith 2020-03-19 16:42:13 -04:00
parent e424173216
commit 84d8adb7f6
2 changed files with 11 additions and 10 deletions

View file

@ -22,8 +22,8 @@ from typing import (
AbstractSet, AbstractSet,
Any, Any,
Dict, Dict,
Iterable,
List, List,
Mapping,
Set, Set,
Tuple, Tuple,
Type, Type,
@ -62,7 +62,7 @@ class HookRegistry:
for hook_name in hook_names: for hook_name in hook_names:
self.add_hook(getattr(module, hook_name)) self.add_hook(getattr(module, hook_name))
def group_by_directive(self, config_str: str='') -> Mapping[HookName, List[Hook]]: def group_by_directive(self, config_str: str='') -> Iterable[Tuple[HookName, Type[Hook]]]:
config_str = config_str.strip() config_str = config_str.strip()
if not config_str: if not config_str:
config_str = 'all' config_str = 'all'
@ -82,10 +82,10 @@ class HookRegistry:
raise ValueError("configuration refers to unknown hooks {!r}".format(key)) from None raise ValueError("configuration refers to unknown hooks {!r}".format(key)) from None
else: else:
update_available(update_set) update_available(update_set)
return { for directive in ALL_DIRECTIVES:
t.__name__: [hook() for hook in self.group_name_map[t.__name__] & available_hooks] key = directive.__name__
for t in ALL_DIRECTIVES for hook in self.group_name_map[key] & available_hooks:
} yield key, hook
HOOK_REGISTRY = HookRegistry() HOOK_REGISTRY = HookRegistry()
@ -99,7 +99,9 @@ def run(
hook_registry: HookRegistry=HOOK_REGISTRY, hook_registry: HookRegistry=HOOK_REGISTRY,
) -> Tuple[List[Directive], List[Error]]: ) -> Tuple[List[Directive], List[Error]]:
errors: List[Error] = [] errors: List[Error] = []
hooks = hook_registry.group_by_directive(config) hooks: Dict[HookName, List[Hook]] = {}
for key, hook_type in hook_registry.group_by_directive(config):
hooks.setdefault(key, []).append(hook_type())
for entry in entries: for entry in entries:
entry_type = type(entry).__name__ entry_type = type(entry).__name__
for hook in hooks[entry_type]: for hook in hooks[entry_type]:

View file

@ -81,15 +81,14 @@ def map_errors(errors):
]) ])
def test_registry_group_by_directive(group_str, expected): def test_registry_group_by_directive(group_str, expected):
args = () if group_str is None else (group_str,) args = () if group_str is None else (group_str,)
hook_groups = HOOK_REGISTRY.group_by_directive(*args) actual = {hook for _, hook in HOOK_REGISTRY.group_by_directive(*args)}
actual = {type(hook) for hook in hook_groups['Transaction']}
assert actual.issuperset(expected) assert actual.issuperset(expected)
if len(expected) == 1: if len(expected) == 1:
assert not (TransactionError in actual and PostingError in actual) assert not (TransactionError in actual and PostingError in actual)
def test_registry_unknown_group_name(): def test_registry_unknown_group_name():
with pytest.raises(ValueError): with pytest.raises(ValueError):
HOOK_REGISTRY.group_by_directive('UnKnownTestGroup') next(HOOK_REGISTRY.group_by_directive('UnKnownTestGroup'))
def test_run_with_multiple_hooks(easy_entries, config_map): def test_run_with_multiple_hooks(easy_entries, config_map):
out_entries, errors = plugin.run(easy_entries, config_map, '', HOOK_REGISTRY) out_entries, errors = plugin.run(easy_entries, config_map, '', HOOK_REGISTRY)