plugin: User configuration is passed to hooks on initialization.
This commit is contained in:
parent
84d8adb7f6
commit
0d370c445b
8 changed files with 62 additions and 9 deletions
|
@ -31,9 +31,21 @@ class Error(Exception):
|
|||
source=self.source,
|
||||
)
|
||||
|
||||
def _fill_source(self, source, filename='conservancy_beancount', lineno=0):
|
||||
source.setdefault('filename', filename)
|
||||
source.setdefault('lineno', lineno)
|
||||
|
||||
|
||||
Iter = Iterable[Error]
|
||||
|
||||
class ConfigurationError(Error):
|
||||
def __init__(self, message, entry=None, source=None):
|
||||
if source is None:
|
||||
source = {}
|
||||
self._fill_source(source)
|
||||
super().__init__(message, entry, source)
|
||||
|
||||
|
||||
class InvalidMetadataError(Error):
|
||||
def __init__(self, txn, post, key, value=None, source=None):
|
||||
if value is None:
|
||||
|
|
|
@ -32,6 +32,7 @@ from ..beancount_types import (
|
|||
ALL_DIRECTIVES,
|
||||
Directive,
|
||||
)
|
||||
from .. import config as configmod
|
||||
from .core import (
|
||||
Hook,
|
||||
HookName,
|
||||
|
@ -100,8 +101,14 @@ def run(
|
|||
) -> Tuple[List[Directive], List[Error]]:
|
||||
errors: List[Error] = []
|
||||
hooks: Dict[HookName, List[Hook]] = {}
|
||||
user_config = configmod.Config()
|
||||
for key, hook_type in hook_registry.group_by_directive(config):
|
||||
hooks.setdefault(key, []).append(hook_type())
|
||||
try:
|
||||
hook = hook_type(user_config)
|
||||
except Error as error:
|
||||
errors.append(error)
|
||||
else:
|
||||
hooks.setdefault(key, []).append(hook)
|
||||
for entry in entries:
|
||||
entry_type = type(entry).__name__
|
||||
for hook in hooks[entry_type]:
|
||||
|
|
|
@ -18,6 +18,7 @@ import abc
|
|||
import datetime
|
||||
import re
|
||||
|
||||
from .. import config as configmod
|
||||
from .. import data
|
||||
from .. import errors as errormod
|
||||
|
||||
|
@ -60,6 +61,11 @@ class Hook(Generic[Entry], metaclass=abc.ABCMeta):
|
|||
DIRECTIVE: Type[Directive]
|
||||
HOOK_GROUPS: FrozenSet[HookName] = frozenset()
|
||||
|
||||
def __init__(self, config: configmod.Config) -> None:
|
||||
pass
|
||||
# Subclasses that need configuration should override __init__ to check
|
||||
# and store it.
|
||||
|
||||
@abc.abstractmethod
|
||||
def run(self, entry: Entry) -> errormod.Iter: ...
|
||||
|
||||
|
|
|
@ -39,7 +39,8 @@ TEST_KEY = 'expense-allocation'
|
|||
|
||||
@pytest.fixture(scope='module')
|
||||
def hook():
|
||||
return meta_expense_allocation.MetaExpenseAllocation()
|
||||
config = testutil.TestConfig()
|
||||
return meta_expense_allocation.MetaExpenseAllocation(config)
|
||||
|
||||
@pytest.mark.parametrize('src_value,set_value', VALID_VALUES.items())
|
||||
def test_valid_values_on_postings(hook, src_value, set_value):
|
||||
|
|
|
@ -39,7 +39,8 @@ TEST_KEY = 'income-type'
|
|||
|
||||
@pytest.fixture(scope='module')
|
||||
def hook():
|
||||
return meta_income_type.MetaIncomeType()
|
||||
config = testutil.TestConfig()
|
||||
return meta_income_type.MetaIncomeType(config)
|
||||
|
||||
@pytest.mark.parametrize('src_value,set_value', VALID_VALUES.items())
|
||||
def test_valid_values_on_postings(hook, src_value, set_value):
|
||||
|
|
|
@ -51,7 +51,8 @@ TEST_KEY = 'tax-implication'
|
|||
|
||||
@pytest.fixture(scope='module')
|
||||
def hook():
|
||||
return meta_tax_implication.MetaTaxImplication()
|
||||
config = testutil.TestConfig()
|
||||
return meta_tax_implication.MetaTaxImplication(config)
|
||||
|
||||
@pytest.mark.parametrize('src_value,set_value', VALID_VALUES.items())
|
||||
def test_valid_values_on_postings(hook, src_value, set_value):
|
||||
|
|
|
@ -18,24 +18,39 @@ import pytest
|
|||
|
||||
from . import testutil
|
||||
|
||||
from conservancy_beancount import beancount_types, plugin
|
||||
from conservancy_beancount import beancount_types, errors as errormod, plugin
|
||||
|
||||
HOOK_REGISTRY = plugin.HookRegistry()
|
||||
|
||||
class NonError(errormod.Error):
|
||||
pass
|
||||
|
||||
|
||||
class TransactionHook:
|
||||
DIRECTIVE = beancount_types.Transaction
|
||||
HOOK_GROUPS = frozenset()
|
||||
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
|
||||
def run(self, txn):
|
||||
assert False, "something called base class run method"
|
||||
|
||||
|
||||
@HOOK_REGISTRY.add_hook
|
||||
class ConfigurationError(TransactionHook):
|
||||
HOOK_GROUPS = frozenset(['unconfigured'])
|
||||
|
||||
def __init__(self, config):
|
||||
raise errormod.ConfigurationError("testing error")
|
||||
|
||||
|
||||
@HOOK_REGISTRY.add_hook
|
||||
class TransactionError(TransactionHook):
|
||||
HOOK_GROUPS = frozenset(['configured'])
|
||||
|
||||
def run(self, txn):
|
||||
return ['txn:{}'.format(id(txn))]
|
||||
return [NonError('txn:{}'.format(id(txn)), txn)]
|
||||
|
||||
|
||||
@HOOK_REGISTRY.add_hook
|
||||
|
@ -43,7 +58,8 @@ class PostingError(TransactionHook):
|
|||
HOOK_GROUPS = frozenset(['configured', 'posting'])
|
||||
|
||||
def run(self, txn):
|
||||
return ['post:{}'.format(id(post)) for post in txn.postings]
|
||||
return [NonError('post:{}'.format(id(post)), txn)
|
||||
for post in txn.postings]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
@ -65,8 +81,8 @@ def easy_entries():
|
|||
|
||||
def map_errors(errors):
|
||||
retval = {}
|
||||
for errkey in errors:
|
||||
key, _, errid = errkey.partition(':')
|
||||
for error in errors:
|
||||
key, _, errid = error.message.partition(':')
|
||||
retval.setdefault(key, set()).add(errid)
|
||||
return retval
|
||||
|
||||
|
|
|
@ -20,6 +20,7 @@ import beancount.core.amount as bc_amount
|
|||
import beancount.core.data as bc_data
|
||||
|
||||
from decimal import Decimal
|
||||
from pathlib import Path
|
||||
|
||||
EXTREME_FUTURE_DATE = datetime.date(datetime.MAXYEAR, 12, 30)
|
||||
FUTURE_DATE = datetime.date.today() + datetime.timedelta(days=365 * 99)
|
||||
|
@ -94,3 +95,11 @@ class Transaction:
|
|||
else:
|
||||
posting = arg
|
||||
self.postings.append(posting)
|
||||
|
||||
|
||||
class TestConfig:
|
||||
def __init__(self, repo_path=None):
|
||||
self.repo_path = None if repo_path is None else Path(repo_path)
|
||||
|
||||
def repository_path(self):
|
||||
return self.repo_path
|
||||
|
|
Loading…
Reference in a new issue