plugin: User configuration is passed to hooks on initialization.
This commit is contained in:
parent
84d8adb7f6
commit
0d370c445b
8 changed files with 62 additions and 9 deletions
|
@ -31,9 +31,21 @@ class Error(Exception):
|
||||||
source=self.source,
|
source=self.source,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _fill_source(self, source, filename='conservancy_beancount', lineno=0):
|
||||||
|
source.setdefault('filename', filename)
|
||||||
|
source.setdefault('lineno', lineno)
|
||||||
|
|
||||||
|
|
||||||
Iter = Iterable[Error]
|
Iter = Iterable[Error]
|
||||||
|
|
||||||
|
class ConfigurationError(Error):
|
||||||
|
def __init__(self, message, entry=None, source=None):
|
||||||
|
if source is None:
|
||||||
|
source = {}
|
||||||
|
self._fill_source(source)
|
||||||
|
super().__init__(message, entry, source)
|
||||||
|
|
||||||
|
|
||||||
class InvalidMetadataError(Error):
|
class InvalidMetadataError(Error):
|
||||||
def __init__(self, txn, post, key, value=None, source=None):
|
def __init__(self, txn, post, key, value=None, source=None):
|
||||||
if value is None:
|
if value is None:
|
||||||
|
|
|
@ -32,6 +32,7 @@ from ..beancount_types import (
|
||||||
ALL_DIRECTIVES,
|
ALL_DIRECTIVES,
|
||||||
Directive,
|
Directive,
|
||||||
)
|
)
|
||||||
|
from .. import config as configmod
|
||||||
from .core import (
|
from .core import (
|
||||||
Hook,
|
Hook,
|
||||||
HookName,
|
HookName,
|
||||||
|
@ -100,8 +101,14 @@ def run(
|
||||||
) -> Tuple[List[Directive], List[Error]]:
|
) -> Tuple[List[Directive], List[Error]]:
|
||||||
errors: List[Error] = []
|
errors: List[Error] = []
|
||||||
hooks: Dict[HookName, List[Hook]] = {}
|
hooks: Dict[HookName, List[Hook]] = {}
|
||||||
|
user_config = configmod.Config()
|
||||||
for key, hook_type in hook_registry.group_by_directive(config):
|
for key, hook_type in hook_registry.group_by_directive(config):
|
||||||
hooks.setdefault(key, []).append(hook_type())
|
try:
|
||||||
|
hook = hook_type(user_config)
|
||||||
|
except Error as error:
|
||||||
|
errors.append(error)
|
||||||
|
else:
|
||||||
|
hooks.setdefault(key, []).append(hook)
|
||||||
for entry in entries:
|
for entry in entries:
|
||||||
entry_type = type(entry).__name__
|
entry_type = type(entry).__name__
|
||||||
for hook in hooks[entry_type]:
|
for hook in hooks[entry_type]:
|
||||||
|
|
|
@ -18,6 +18,7 @@ import abc
|
||||||
import datetime
|
import datetime
|
||||||
import re
|
import re
|
||||||
|
|
||||||
|
from .. import config as configmod
|
||||||
from .. import data
|
from .. import data
|
||||||
from .. import errors as errormod
|
from .. import errors as errormod
|
||||||
|
|
||||||
|
@ -60,6 +61,11 @@ class Hook(Generic[Entry], metaclass=abc.ABCMeta):
|
||||||
DIRECTIVE: Type[Directive]
|
DIRECTIVE: Type[Directive]
|
||||||
HOOK_GROUPS: FrozenSet[HookName] = frozenset()
|
HOOK_GROUPS: FrozenSet[HookName] = frozenset()
|
||||||
|
|
||||||
|
def __init__(self, config: configmod.Config) -> None:
|
||||||
|
pass
|
||||||
|
# Subclasses that need configuration should override __init__ to check
|
||||||
|
# and store it.
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def run(self, entry: Entry) -> errormod.Iter: ...
|
def run(self, entry: Entry) -> errormod.Iter: ...
|
||||||
|
|
||||||
|
|
|
@ -39,7 +39,8 @@ TEST_KEY = 'expense-allocation'
|
||||||
|
|
||||||
@pytest.fixture(scope='module')
|
@pytest.fixture(scope='module')
|
||||||
def hook():
|
def hook():
|
||||||
return meta_expense_allocation.MetaExpenseAllocation()
|
config = testutil.TestConfig()
|
||||||
|
return meta_expense_allocation.MetaExpenseAllocation(config)
|
||||||
|
|
||||||
@pytest.mark.parametrize('src_value,set_value', VALID_VALUES.items())
|
@pytest.mark.parametrize('src_value,set_value', VALID_VALUES.items())
|
||||||
def test_valid_values_on_postings(hook, src_value, set_value):
|
def test_valid_values_on_postings(hook, src_value, set_value):
|
||||||
|
|
|
@ -39,7 +39,8 @@ TEST_KEY = 'income-type'
|
||||||
|
|
||||||
@pytest.fixture(scope='module')
|
@pytest.fixture(scope='module')
|
||||||
def hook():
|
def hook():
|
||||||
return meta_income_type.MetaIncomeType()
|
config = testutil.TestConfig()
|
||||||
|
return meta_income_type.MetaIncomeType(config)
|
||||||
|
|
||||||
@pytest.mark.parametrize('src_value,set_value', VALID_VALUES.items())
|
@pytest.mark.parametrize('src_value,set_value', VALID_VALUES.items())
|
||||||
def test_valid_values_on_postings(hook, src_value, set_value):
|
def test_valid_values_on_postings(hook, src_value, set_value):
|
||||||
|
|
|
@ -51,7 +51,8 @@ TEST_KEY = 'tax-implication'
|
||||||
|
|
||||||
@pytest.fixture(scope='module')
|
@pytest.fixture(scope='module')
|
||||||
def hook():
|
def hook():
|
||||||
return meta_tax_implication.MetaTaxImplication()
|
config = testutil.TestConfig()
|
||||||
|
return meta_tax_implication.MetaTaxImplication(config)
|
||||||
|
|
||||||
@pytest.mark.parametrize('src_value,set_value', VALID_VALUES.items())
|
@pytest.mark.parametrize('src_value,set_value', VALID_VALUES.items())
|
||||||
def test_valid_values_on_postings(hook, src_value, set_value):
|
def test_valid_values_on_postings(hook, src_value, set_value):
|
||||||
|
|
|
@ -18,24 +18,39 @@ import pytest
|
||||||
|
|
||||||
from . import testutil
|
from . import testutil
|
||||||
|
|
||||||
from conservancy_beancount import beancount_types, plugin
|
from conservancy_beancount import beancount_types, errors as errormod, plugin
|
||||||
|
|
||||||
HOOK_REGISTRY = plugin.HookRegistry()
|
HOOK_REGISTRY = plugin.HookRegistry()
|
||||||
|
|
||||||
|
class NonError(errormod.Error):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class TransactionHook:
|
class TransactionHook:
|
||||||
DIRECTIVE = beancount_types.Transaction
|
DIRECTIVE = beancount_types.Transaction
|
||||||
HOOK_GROUPS = frozenset()
|
HOOK_GROUPS = frozenset()
|
||||||
|
|
||||||
|
def __init__(self, config):
|
||||||
|
self.config = config
|
||||||
|
|
||||||
def run(self, txn):
|
def run(self, txn):
|
||||||
assert False, "something called base class run method"
|
assert False, "something called base class run method"
|
||||||
|
|
||||||
|
|
||||||
|
@HOOK_REGISTRY.add_hook
|
||||||
|
class ConfigurationError(TransactionHook):
|
||||||
|
HOOK_GROUPS = frozenset(['unconfigured'])
|
||||||
|
|
||||||
|
def __init__(self, config):
|
||||||
|
raise errormod.ConfigurationError("testing error")
|
||||||
|
|
||||||
|
|
||||||
@HOOK_REGISTRY.add_hook
|
@HOOK_REGISTRY.add_hook
|
||||||
class TransactionError(TransactionHook):
|
class TransactionError(TransactionHook):
|
||||||
HOOK_GROUPS = frozenset(['configured'])
|
HOOK_GROUPS = frozenset(['configured'])
|
||||||
|
|
||||||
def run(self, txn):
|
def run(self, txn):
|
||||||
return ['txn:{}'.format(id(txn))]
|
return [NonError('txn:{}'.format(id(txn)), txn)]
|
||||||
|
|
||||||
|
|
||||||
@HOOK_REGISTRY.add_hook
|
@HOOK_REGISTRY.add_hook
|
||||||
|
@ -43,7 +58,8 @@ class PostingError(TransactionHook):
|
||||||
HOOK_GROUPS = frozenset(['configured', 'posting'])
|
HOOK_GROUPS = frozenset(['configured', 'posting'])
|
||||||
|
|
||||||
def run(self, txn):
|
def run(self, txn):
|
||||||
return ['post:{}'.format(id(post)) for post in txn.postings]
|
return [NonError('post:{}'.format(id(post)), txn)
|
||||||
|
for post in txn.postings]
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
@ -65,8 +81,8 @@ def easy_entries():
|
||||||
|
|
||||||
def map_errors(errors):
|
def map_errors(errors):
|
||||||
retval = {}
|
retval = {}
|
||||||
for errkey in errors:
|
for error in errors:
|
||||||
key, _, errid = errkey.partition(':')
|
key, _, errid = error.message.partition(':')
|
||||||
retval.setdefault(key, set()).add(errid)
|
retval.setdefault(key, set()).add(errid)
|
||||||
return retval
|
return retval
|
||||||
|
|
||||||
|
|
|
@ -20,6 +20,7 @@ import beancount.core.amount as bc_amount
|
||||||
import beancount.core.data as bc_data
|
import beancount.core.data as bc_data
|
||||||
|
|
||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
EXTREME_FUTURE_DATE = datetime.date(datetime.MAXYEAR, 12, 30)
|
EXTREME_FUTURE_DATE = datetime.date(datetime.MAXYEAR, 12, 30)
|
||||||
FUTURE_DATE = datetime.date.today() + datetime.timedelta(days=365 * 99)
|
FUTURE_DATE = datetime.date.today() + datetime.timedelta(days=365 * 99)
|
||||||
|
@ -94,3 +95,11 @@ class Transaction:
|
||||||
else:
|
else:
|
||||||
posting = arg
|
posting = arg
|
||||||
self.postings.append(posting)
|
self.postings.append(posting)
|
||||||
|
|
||||||
|
|
||||||
|
class TestConfig:
|
||||||
|
def __init__(self, repo_path=None):
|
||||||
|
self.repo_path = None if repo_path is None else Path(repo_path)
|
||||||
|
|
||||||
|
def repository_path(self):
|
||||||
|
return self.repo_path
|
||||||
|
|
Loading…
Add table
Reference in a new issue