diff --git a/tests/test_plugin.py b/tests/test_plugin.py index 9ae955f..f722d3c 100644 --- a/tests/test_plugin.py +++ b/tests/test_plugin.py @@ -63,9 +63,6 @@ def easy_entries(): ]), ] -def hook_types(hooks, key): - return {type(hook) for hook in hooks[key]} - def map_errors(errors): retval = {} for errkey in errors: @@ -73,33 +70,20 @@ def map_errors(errors): retval.setdefault(key, set()).add(errid) return retval -def test_registry_all_by_default(): - hook_groups = HOOK_REGISTRY.group_by_directive() - hooks = hook_types(hook_groups, 'Transaction') - assert len(hooks) >= 2 - assert TransactionError in hooks - assert PostingError in hooks - -def test_registry_one_exclude(): - hook_groups = HOOK_REGISTRY.group_by_directive('-posting') - hooks = hook_types(hook_groups, 'Transaction') - assert len(hooks) >= 1 - assert TransactionError in hooks - assert PostingError not in hooks - -def test_registry_exclude_then_include(): - hook_groups = HOOK_REGISTRY.group_by_directive('-configured posting') - hooks = hook_types(hook_groups, 'Transaction') - assert len(hooks) >= 1 - assert TransactionError not in hooks - assert PostingError in hooks - -def test_registry_include_then_exclude(): - hook_groups = HOOK_REGISTRY.group_by_directive('configured -posting') - hooks = hook_types(hook_groups, 'Transaction') - assert len(hooks) >= 1 - assert TransactionError in hooks - assert PostingError not in hooks +@pytest.mark.parametrize('group_str,expected', [ + (None, [TransactionError, PostingError]), + ('', [TransactionError, PostingError]), + ('all', [TransactionError, PostingError]), + ('Transaction', [TransactionError, PostingError]), + ('-posting', [TransactionError]), + ('-configured posting', [PostingError]), + ('configured -posting', [TransactionError]), +]) +def test_registry_group_by_directive(group_str, expected): + args = () if group_str is None else (group_str,) + hook_groups = HOOK_REGISTRY.group_by_directive(*args) + actual = {type(hook) for hook in hook_groups['Transaction']} + assert actual.issuperset(expected) def test_registry_unknown_group_name(): with pytest.raises(ValueError):