plugin: User configuration is passed to hooks on initialization.
This commit is contained in:
		
							parent
							
								
									84d8adb7f6
								
							
						
					
					
						commit
						0d370c445b
					
				
					 8 changed files with 62 additions and 9 deletions
				
			
		| 
						 | 
					@ -31,9 +31,21 @@ class Error(Exception):
 | 
				
			||||||
            source=self.source,
 | 
					            source=self.source,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def _fill_source(self, source, filename='conservancy_beancount', lineno=0):
 | 
				
			||||||
 | 
					        source.setdefault('filename', filename)
 | 
				
			||||||
 | 
					        source.setdefault('lineno', lineno)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
Iter = Iterable[Error]
 | 
					Iter = Iterable[Error]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class ConfigurationError(Error):
 | 
				
			||||||
 | 
					    def __init__(self, message, entry=None, source=None):
 | 
				
			||||||
 | 
					        if source is None:
 | 
				
			||||||
 | 
					            source = {}
 | 
				
			||||||
 | 
					        self._fill_source(source)
 | 
				
			||||||
 | 
					        super().__init__(message, entry, source)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class InvalidMetadataError(Error):
 | 
					class InvalidMetadataError(Error):
 | 
				
			||||||
    def __init__(self, txn, post, key, value=None, source=None):
 | 
					    def __init__(self, txn, post, key, value=None, source=None):
 | 
				
			||||||
        if value is None:
 | 
					        if value is None:
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -32,6 +32,7 @@ from ..beancount_types import (
 | 
				
			||||||
    ALL_DIRECTIVES,
 | 
					    ALL_DIRECTIVES,
 | 
				
			||||||
    Directive,
 | 
					    Directive,
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					from .. import config as configmod
 | 
				
			||||||
from .core import (
 | 
					from .core import (
 | 
				
			||||||
    Hook,
 | 
					    Hook,
 | 
				
			||||||
    HookName,
 | 
					    HookName,
 | 
				
			||||||
| 
						 | 
					@ -100,8 +101,14 @@ def run(
 | 
				
			||||||
) -> Tuple[List[Directive], List[Error]]:
 | 
					) -> Tuple[List[Directive], List[Error]]:
 | 
				
			||||||
    errors: List[Error] = []
 | 
					    errors: List[Error] = []
 | 
				
			||||||
    hooks: Dict[HookName, List[Hook]] = {}
 | 
					    hooks: Dict[HookName, List[Hook]] = {}
 | 
				
			||||||
 | 
					    user_config = configmod.Config()
 | 
				
			||||||
    for key, hook_type in hook_registry.group_by_directive(config):
 | 
					    for key, hook_type in hook_registry.group_by_directive(config):
 | 
				
			||||||
        hooks.setdefault(key, []).append(hook_type())
 | 
					        try:
 | 
				
			||||||
 | 
					            hook = hook_type(user_config)
 | 
				
			||||||
 | 
					        except Error as error:
 | 
				
			||||||
 | 
					            errors.append(error)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            hooks.setdefault(key, []).append(hook)
 | 
				
			||||||
    for entry in entries:
 | 
					    for entry in entries:
 | 
				
			||||||
        entry_type = type(entry).__name__
 | 
					        entry_type = type(entry).__name__
 | 
				
			||||||
        for hook in hooks[entry_type]:
 | 
					        for hook in hooks[entry_type]:
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -18,6 +18,7 @@ import abc
 | 
				
			||||||
import datetime
 | 
					import datetime
 | 
				
			||||||
import re
 | 
					import re
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from .. import config as configmod
 | 
				
			||||||
from .. import data
 | 
					from .. import data
 | 
				
			||||||
from .. import errors as errormod
 | 
					from .. import errors as errormod
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -60,6 +61,11 @@ class Hook(Generic[Entry], metaclass=abc.ABCMeta):
 | 
				
			||||||
    DIRECTIVE: Type[Directive]
 | 
					    DIRECTIVE: Type[Directive]
 | 
				
			||||||
    HOOK_GROUPS: FrozenSet[HookName] = frozenset()
 | 
					    HOOK_GROUPS: FrozenSet[HookName] = frozenset()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(self, config: configmod.Config) -> None:
 | 
				
			||||||
 | 
					        pass
 | 
				
			||||||
 | 
					        # Subclasses that need configuration should override __init__ to check
 | 
				
			||||||
 | 
					        # and store it.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @abc.abstractmethod
 | 
					    @abc.abstractmethod
 | 
				
			||||||
    def run(self, entry: Entry) -> errormod.Iter: ...
 | 
					    def run(self, entry: Entry) -> errormod.Iter: ...
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -39,7 +39,8 @@ TEST_KEY = 'expense-allocation'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@pytest.fixture(scope='module')
 | 
					@pytest.fixture(scope='module')
 | 
				
			||||||
def hook():
 | 
					def hook():
 | 
				
			||||||
    return meta_expense_allocation.MetaExpenseAllocation()
 | 
					    config = testutil.TestConfig()
 | 
				
			||||||
 | 
					    return meta_expense_allocation.MetaExpenseAllocation(config)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@pytest.mark.parametrize('src_value,set_value', VALID_VALUES.items())
 | 
					@pytest.mark.parametrize('src_value,set_value', VALID_VALUES.items())
 | 
				
			||||||
def test_valid_values_on_postings(hook, src_value, set_value):
 | 
					def test_valid_values_on_postings(hook, src_value, set_value):
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -39,7 +39,8 @@ TEST_KEY = 'income-type'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@pytest.fixture(scope='module')
 | 
					@pytest.fixture(scope='module')
 | 
				
			||||||
def hook():
 | 
					def hook():
 | 
				
			||||||
    return meta_income_type.MetaIncomeType()
 | 
					    config = testutil.TestConfig()
 | 
				
			||||||
 | 
					    return meta_income_type.MetaIncomeType(config)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@pytest.mark.parametrize('src_value,set_value', VALID_VALUES.items())
 | 
					@pytest.mark.parametrize('src_value,set_value', VALID_VALUES.items())
 | 
				
			||||||
def test_valid_values_on_postings(hook, src_value, set_value):
 | 
					def test_valid_values_on_postings(hook, src_value, set_value):
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -51,7 +51,8 @@ TEST_KEY = 'tax-implication'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@pytest.fixture(scope='module')
 | 
					@pytest.fixture(scope='module')
 | 
				
			||||||
def hook():
 | 
					def hook():
 | 
				
			||||||
    return meta_tax_implication.MetaTaxImplication()
 | 
					    config = testutil.TestConfig()
 | 
				
			||||||
 | 
					    return meta_tax_implication.MetaTaxImplication(config)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@pytest.mark.parametrize('src_value,set_value', VALID_VALUES.items())
 | 
					@pytest.mark.parametrize('src_value,set_value', VALID_VALUES.items())
 | 
				
			||||||
def test_valid_values_on_postings(hook, src_value, set_value):
 | 
					def test_valid_values_on_postings(hook, src_value, set_value):
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -18,24 +18,39 @@ import pytest
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from . import testutil
 | 
					from . import testutil
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from conservancy_beancount import beancount_types, plugin
 | 
					from conservancy_beancount import beancount_types, errors as errormod, plugin
 | 
				
			||||||
 | 
					
 | 
				
			||||||
HOOK_REGISTRY = plugin.HookRegistry()
 | 
					HOOK_REGISTRY = plugin.HookRegistry()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class NonError(errormod.Error):
 | 
				
			||||||
 | 
					    pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class TransactionHook:
 | 
					class TransactionHook:
 | 
				
			||||||
    DIRECTIVE = beancount_types.Transaction
 | 
					    DIRECTIVE = beancount_types.Transaction
 | 
				
			||||||
    HOOK_GROUPS = frozenset()
 | 
					    HOOK_GROUPS = frozenset()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(self, config):
 | 
				
			||||||
 | 
					        self.config = config
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def run(self, txn):
 | 
					    def run(self, txn):
 | 
				
			||||||
        assert False, "something called base class run method"
 | 
					        assert False, "something called base class run method"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@HOOK_REGISTRY.add_hook
 | 
				
			||||||
 | 
					class ConfigurationError(TransactionHook):
 | 
				
			||||||
 | 
					    HOOK_GROUPS = frozenset(['unconfigured'])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(self, config):
 | 
				
			||||||
 | 
					        raise errormod.ConfigurationError("testing error")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@HOOK_REGISTRY.add_hook
 | 
					@HOOK_REGISTRY.add_hook
 | 
				
			||||||
class TransactionError(TransactionHook):
 | 
					class TransactionError(TransactionHook):
 | 
				
			||||||
    HOOK_GROUPS = frozenset(['configured'])
 | 
					    HOOK_GROUPS = frozenset(['configured'])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def run(self, txn):
 | 
					    def run(self, txn):
 | 
				
			||||||
        return ['txn:{}'.format(id(txn))]
 | 
					        return [NonError('txn:{}'.format(id(txn)), txn)]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@HOOK_REGISTRY.add_hook
 | 
					@HOOK_REGISTRY.add_hook
 | 
				
			||||||
| 
						 | 
					@ -43,7 +58,8 @@ class PostingError(TransactionHook):
 | 
				
			||||||
    HOOK_GROUPS = frozenset(['configured', 'posting'])
 | 
					    HOOK_GROUPS = frozenset(['configured', 'posting'])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def run(self, txn):
 | 
					    def run(self, txn):
 | 
				
			||||||
        return ['post:{}'.format(id(post)) for post in txn.postings]
 | 
					        return [NonError('post:{}'.format(id(post)), txn)
 | 
				
			||||||
 | 
					                for post in txn.postings]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@pytest.fixture
 | 
					@pytest.fixture
 | 
				
			||||||
| 
						 | 
					@ -65,8 +81,8 @@ def easy_entries():
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def map_errors(errors):
 | 
					def map_errors(errors):
 | 
				
			||||||
    retval = {}
 | 
					    retval = {}
 | 
				
			||||||
    for errkey in errors:
 | 
					    for error in errors:
 | 
				
			||||||
        key, _, errid = errkey.partition(':')
 | 
					        key, _, errid = error.message.partition(':')
 | 
				
			||||||
        retval.setdefault(key, set()).add(errid)
 | 
					        retval.setdefault(key, set()).add(errid)
 | 
				
			||||||
    return retval
 | 
					    return retval
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -20,6 +20,7 @@ import beancount.core.amount as bc_amount
 | 
				
			||||||
import beancount.core.data as bc_data
 | 
					import beancount.core.data as bc_data
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from decimal import Decimal
 | 
					from decimal import Decimal
 | 
				
			||||||
 | 
					from pathlib import Path
 | 
				
			||||||
 | 
					
 | 
				
			||||||
EXTREME_FUTURE_DATE = datetime.date(datetime.MAXYEAR, 12, 30)
 | 
					EXTREME_FUTURE_DATE = datetime.date(datetime.MAXYEAR, 12, 30)
 | 
				
			||||||
FUTURE_DATE = datetime.date.today() + datetime.timedelta(days=365 * 99)
 | 
					FUTURE_DATE = datetime.date.today() + datetime.timedelta(days=365 * 99)
 | 
				
			||||||
| 
						 | 
					@ -94,3 +95,11 @@ class Transaction:
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            posting = arg
 | 
					            posting = arg
 | 
				
			||||||
        self.postings.append(posting)
 | 
					        self.postings.append(posting)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class TestConfig:
 | 
				
			||||||
 | 
					    def __init__(self, repo_path=None):
 | 
				
			||||||
 | 
					        self.repo_path = None if repo_path is None else Path(repo_path)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def repository_path(self):
 | 
				
			||||||
 | 
					        return self.repo_path
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
	Add table
		
		Reference in a new issue