plugin: Move hook initialization from HookRegistry to run().
Makes more sense here so run can report errors in hook configuration.
This commit is contained in:
parent
e424173216
commit
84d8adb7f6
2 changed files with 11 additions and 10 deletions
|
@ -22,8 +22,8 @@ from typing import (
|
||||||
AbstractSet,
|
AbstractSet,
|
||||||
Any,
|
Any,
|
||||||
Dict,
|
Dict,
|
||||||
|
Iterable,
|
||||||
List,
|
List,
|
||||||
Mapping,
|
|
||||||
Set,
|
Set,
|
||||||
Tuple,
|
Tuple,
|
||||||
Type,
|
Type,
|
||||||
|
@ -62,7 +62,7 @@ class HookRegistry:
|
||||||
for hook_name in hook_names:
|
for hook_name in hook_names:
|
||||||
self.add_hook(getattr(module, hook_name))
|
self.add_hook(getattr(module, hook_name))
|
||||||
|
|
||||||
def group_by_directive(self, config_str: str='') -> Mapping[HookName, List[Hook]]:
|
def group_by_directive(self, config_str: str='') -> Iterable[Tuple[HookName, Type[Hook]]]:
|
||||||
config_str = config_str.strip()
|
config_str = config_str.strip()
|
||||||
if not config_str:
|
if not config_str:
|
||||||
config_str = 'all'
|
config_str = 'all'
|
||||||
|
@ -82,10 +82,10 @@ class HookRegistry:
|
||||||
raise ValueError("configuration refers to unknown hooks {!r}".format(key)) from None
|
raise ValueError("configuration refers to unknown hooks {!r}".format(key)) from None
|
||||||
else:
|
else:
|
||||||
update_available(update_set)
|
update_available(update_set)
|
||||||
return {
|
for directive in ALL_DIRECTIVES:
|
||||||
t.__name__: [hook() for hook in self.group_name_map[t.__name__] & available_hooks]
|
key = directive.__name__
|
||||||
for t in ALL_DIRECTIVES
|
for hook in self.group_name_map[key] & available_hooks:
|
||||||
}
|
yield key, hook
|
||||||
|
|
||||||
|
|
||||||
HOOK_REGISTRY = HookRegistry()
|
HOOK_REGISTRY = HookRegistry()
|
||||||
|
@ -99,7 +99,9 @@ def run(
|
||||||
hook_registry: HookRegistry=HOOK_REGISTRY,
|
hook_registry: HookRegistry=HOOK_REGISTRY,
|
||||||
) -> Tuple[List[Directive], List[Error]]:
|
) -> Tuple[List[Directive], List[Error]]:
|
||||||
errors: List[Error] = []
|
errors: List[Error] = []
|
||||||
hooks = hook_registry.group_by_directive(config)
|
hooks: Dict[HookName, List[Hook]] = {}
|
||||||
|
for key, hook_type in hook_registry.group_by_directive(config):
|
||||||
|
hooks.setdefault(key, []).append(hook_type())
|
||||||
for entry in entries:
|
for entry in entries:
|
||||||
entry_type = type(entry).__name__
|
entry_type = type(entry).__name__
|
||||||
for hook in hooks[entry_type]:
|
for hook in hooks[entry_type]:
|
||||||
|
|
|
@ -81,15 +81,14 @@ def map_errors(errors):
|
||||||
])
|
])
|
||||||
def test_registry_group_by_directive(group_str, expected):
|
def test_registry_group_by_directive(group_str, expected):
|
||||||
args = () if group_str is None else (group_str,)
|
args = () if group_str is None else (group_str,)
|
||||||
hook_groups = HOOK_REGISTRY.group_by_directive(*args)
|
actual = {hook for _, hook in HOOK_REGISTRY.group_by_directive(*args)}
|
||||||
actual = {type(hook) for hook in hook_groups['Transaction']}
|
|
||||||
assert actual.issuperset(expected)
|
assert actual.issuperset(expected)
|
||||||
if len(expected) == 1:
|
if len(expected) == 1:
|
||||||
assert not (TransactionError in actual and PostingError in actual)
|
assert not (TransactionError in actual and PostingError in actual)
|
||||||
|
|
||||||
def test_registry_unknown_group_name():
|
def test_registry_unknown_group_name():
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
HOOK_REGISTRY.group_by_directive('UnKnownTestGroup')
|
next(HOOK_REGISTRY.group_by_directive('UnKnownTestGroup'))
|
||||||
|
|
||||||
def test_run_with_multiple_hooks(easy_entries, config_map):
|
def test_run_with_multiple_hooks(easy_entries, config_map):
|
||||||
out_entries, errors = plugin.run(easy_entries, config_map, '', HOOK_REGISTRY)
|
out_entries, errors = plugin.run(easy_entries, config_map, '', HOOK_REGISTRY)
|
||||||
|
|
Loading…
Add table
Reference in a new issue