plugin: User configuration is passed to hooks on initialization.

This commit is contained in:
Brett Smith 2020-03-19 17:23:27 -04:00
parent 84d8adb7f6
commit 0d370c445b
8 changed files with 62 additions and 9 deletions

View file

@ -31,9 +31,21 @@ class Error(Exception):
source=self.source, source=self.source,
) )
def _fill_source(self, source, filename='conservancy_beancount', lineno=0):
source.setdefault('filename', filename)
source.setdefault('lineno', lineno)
Iter = Iterable[Error] Iter = Iterable[Error]
class ConfigurationError(Error):
def __init__(self, message, entry=None, source=None):
if source is None:
source = {}
self._fill_source(source)
super().__init__(message, entry, source)
class InvalidMetadataError(Error): class InvalidMetadataError(Error):
def __init__(self, txn, post, key, value=None, source=None): def __init__(self, txn, post, key, value=None, source=None):
if value is None: if value is None:

View file

@ -32,6 +32,7 @@ from ..beancount_types import (
ALL_DIRECTIVES, ALL_DIRECTIVES,
Directive, Directive,
) )
from .. import config as configmod
from .core import ( from .core import (
Hook, Hook,
HookName, HookName,
@ -100,8 +101,14 @@ def run(
) -> Tuple[List[Directive], List[Error]]: ) -> Tuple[List[Directive], List[Error]]:
errors: List[Error] = [] errors: List[Error] = []
hooks: Dict[HookName, List[Hook]] = {} hooks: Dict[HookName, List[Hook]] = {}
user_config = configmod.Config()
for key, hook_type in hook_registry.group_by_directive(config): for key, hook_type in hook_registry.group_by_directive(config):
hooks.setdefault(key, []).append(hook_type()) try:
hook = hook_type(user_config)
except Error as error:
errors.append(error)
else:
hooks.setdefault(key, []).append(hook)
for entry in entries: for entry in entries:
entry_type = type(entry).__name__ entry_type = type(entry).__name__
for hook in hooks[entry_type]: for hook in hooks[entry_type]:

View file

@ -18,6 +18,7 @@ import abc
import datetime import datetime
import re import re
from .. import config as configmod
from .. import data from .. import data
from .. import errors as errormod from .. import errors as errormod
@ -60,6 +61,11 @@ class Hook(Generic[Entry], metaclass=abc.ABCMeta):
DIRECTIVE: Type[Directive] DIRECTIVE: Type[Directive]
HOOK_GROUPS: FrozenSet[HookName] = frozenset() HOOK_GROUPS: FrozenSet[HookName] = frozenset()
def __init__(self, config: configmod.Config) -> None:
pass
# Subclasses that need configuration should override __init__ to check
# and store it.
@abc.abstractmethod @abc.abstractmethod
def run(self, entry: Entry) -> errormod.Iter: ... def run(self, entry: Entry) -> errormod.Iter: ...

View file

@ -39,7 +39,8 @@ TEST_KEY = 'expense-allocation'
@pytest.fixture(scope='module') @pytest.fixture(scope='module')
def hook(): def hook():
return meta_expense_allocation.MetaExpenseAllocation() config = testutil.TestConfig()
return meta_expense_allocation.MetaExpenseAllocation(config)
@pytest.mark.parametrize('src_value,set_value', VALID_VALUES.items()) @pytest.mark.parametrize('src_value,set_value', VALID_VALUES.items())
def test_valid_values_on_postings(hook, src_value, set_value): def test_valid_values_on_postings(hook, src_value, set_value):

View file

@ -39,7 +39,8 @@ TEST_KEY = 'income-type'
@pytest.fixture(scope='module') @pytest.fixture(scope='module')
def hook(): def hook():
return meta_income_type.MetaIncomeType() config = testutil.TestConfig()
return meta_income_type.MetaIncomeType(config)
@pytest.mark.parametrize('src_value,set_value', VALID_VALUES.items()) @pytest.mark.parametrize('src_value,set_value', VALID_VALUES.items())
def test_valid_values_on_postings(hook, src_value, set_value): def test_valid_values_on_postings(hook, src_value, set_value):

View file

@ -51,7 +51,8 @@ TEST_KEY = 'tax-implication'
@pytest.fixture(scope='module') @pytest.fixture(scope='module')
def hook(): def hook():
return meta_tax_implication.MetaTaxImplication() config = testutil.TestConfig()
return meta_tax_implication.MetaTaxImplication(config)
@pytest.mark.parametrize('src_value,set_value', VALID_VALUES.items()) @pytest.mark.parametrize('src_value,set_value', VALID_VALUES.items())
def test_valid_values_on_postings(hook, src_value, set_value): def test_valid_values_on_postings(hook, src_value, set_value):

View file

@ -18,24 +18,39 @@ import pytest
from . import testutil from . import testutil
from conservancy_beancount import beancount_types, plugin from conservancy_beancount import beancount_types, errors as errormod, plugin
HOOK_REGISTRY = plugin.HookRegistry() HOOK_REGISTRY = plugin.HookRegistry()
class NonError(errormod.Error):
pass
class TransactionHook: class TransactionHook:
DIRECTIVE = beancount_types.Transaction DIRECTIVE = beancount_types.Transaction
HOOK_GROUPS = frozenset() HOOK_GROUPS = frozenset()
def __init__(self, config):
self.config = config
def run(self, txn): def run(self, txn):
assert False, "something called base class run method" assert False, "something called base class run method"
@HOOK_REGISTRY.add_hook
class ConfigurationError(TransactionHook):
HOOK_GROUPS = frozenset(['unconfigured'])
def __init__(self, config):
raise errormod.ConfigurationError("testing error")
@HOOK_REGISTRY.add_hook @HOOK_REGISTRY.add_hook
class TransactionError(TransactionHook): class TransactionError(TransactionHook):
HOOK_GROUPS = frozenset(['configured']) HOOK_GROUPS = frozenset(['configured'])
def run(self, txn): def run(self, txn):
return ['txn:{}'.format(id(txn))] return [NonError('txn:{}'.format(id(txn)), txn)]
@HOOK_REGISTRY.add_hook @HOOK_REGISTRY.add_hook
@ -43,7 +58,8 @@ class PostingError(TransactionHook):
HOOK_GROUPS = frozenset(['configured', 'posting']) HOOK_GROUPS = frozenset(['configured', 'posting'])
def run(self, txn): def run(self, txn):
return ['post:{}'.format(id(post)) for post in txn.postings] return [NonError('post:{}'.format(id(post)), txn)
for post in txn.postings]
@pytest.fixture @pytest.fixture
@ -65,8 +81,8 @@ def easy_entries():
def map_errors(errors): def map_errors(errors):
retval = {} retval = {}
for errkey in errors: for error in errors:
key, _, errid = errkey.partition(':') key, _, errid = error.message.partition(':')
retval.setdefault(key, set()).add(errid) retval.setdefault(key, set()).add(errid)
return retval return retval

View file

@ -20,6 +20,7 @@ import beancount.core.amount as bc_amount
import beancount.core.data as bc_data import beancount.core.data as bc_data
from decimal import Decimal from decimal import Decimal
from pathlib import Path
EXTREME_FUTURE_DATE = datetime.date(datetime.MAXYEAR, 12, 30) EXTREME_FUTURE_DATE = datetime.date(datetime.MAXYEAR, 12, 30)
FUTURE_DATE = datetime.date.today() + datetime.timedelta(days=365 * 99) FUTURE_DATE = datetime.date.today() + datetime.timedelta(days=365 * 99)
@ -94,3 +95,11 @@ class Transaction:
else: else:
posting = arg posting = arg
self.postings.append(posting) self.postings.append(posting)
class TestConfig:
def __init__(self, repo_path=None):
self.repo_path = None if repo_path is None else Path(repo_path)
def repository_path(self):
return self.repo_path