plugin: Transform posting hooks into transaction hooks.
I feel like posting hooks a case of premature optimization in early development. This approach reduces the number of special cases in the code and allows us to more strongly reason about hooks in the type system.
This commit is contained in:
parent
c9ff4ab746
commit
a41feb94b3
9 changed files with 180 additions and 122 deletions
|
@ -18,9 +18,12 @@ import abc
|
||||||
import datetime
|
import datetime
|
||||||
|
|
||||||
import beancount.core.data as bc_data
|
import beancount.core.data as bc_data
|
||||||
|
from .plugin import errors
|
||||||
|
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
|
FrozenSet,
|
||||||
|
Iterable,
|
||||||
List,
|
List,
|
||||||
NamedTuple,
|
NamedTuple,
|
||||||
Optional,
|
Optional,
|
||||||
|
@ -30,7 +33,8 @@ from typing import (
|
||||||
)
|
)
|
||||||
|
|
||||||
Account = bc_data.Account
|
Account = bc_data.Account
|
||||||
HookName = str
|
Error = errors._BaseError
|
||||||
|
ErrorIter = Iterable[Error]
|
||||||
MetaKey = str
|
MetaKey = str
|
||||||
MetaValue = Any
|
MetaValue = Any
|
||||||
MetaValueEnum = str
|
MetaValueEnum = str
|
||||||
|
@ -56,3 +60,8 @@ class Transaction(Directive):
|
||||||
tags: Set
|
tags: Set
|
||||||
links: Set
|
links: Set
|
||||||
postings: List[Posting]
|
postings: List[Posting]
|
||||||
|
|
||||||
|
|
||||||
|
ALL_DIRECTIVES: FrozenSet[Type[Directive]] = frozenset([
|
||||||
|
Transaction,
|
||||||
|
])
|
||||||
|
|
|
@ -18,31 +18,40 @@ import importlib
|
||||||
|
|
||||||
import beancount.core.data as bc_data
|
import beancount.core.data as bc_data
|
||||||
|
|
||||||
|
from typing import (
|
||||||
|
AbstractSet,
|
||||||
|
Any,
|
||||||
|
Dict,
|
||||||
|
List,
|
||||||
|
Mapping,
|
||||||
|
Set,
|
||||||
|
Tuple,
|
||||||
|
Type,
|
||||||
|
)
|
||||||
|
from .._typing import (
|
||||||
|
ALL_DIRECTIVES,
|
||||||
|
Directive,
|
||||||
|
Error,
|
||||||
|
)
|
||||||
|
from .core import (
|
||||||
|
Hook,
|
||||||
|
HookName,
|
||||||
|
)
|
||||||
|
|
||||||
__plugins__ = ['run']
|
__plugins__ = ['run']
|
||||||
|
|
||||||
class HookRegistry:
|
class HookRegistry:
|
||||||
DIRECTIVES = frozenset([
|
def __init__(self) -> None:
|
||||||
*(cls.__name__ for cls in bc_data.ALL_DIRECTIVES),
|
self.group_name_map: Dict[HookName, Set[Type[Hook]]] = {
|
||||||
'Posting',
|
t.__name__: set() for t in ALL_DIRECTIVES
|
||||||
])
|
}
|
||||||
|
self.group_name_map['all'] = set()
|
||||||
|
|
||||||
def __init__(self):
|
def add_hook(self, hook_cls: Type[Hook]) -> Type[Hook]:
|
||||||
self.group_hooks_map = {key: set() for key in self.DIRECTIVES}
|
self.group_name_map['all'].add(hook_cls)
|
||||||
|
self.group_name_map[hook_cls.DIRECTIVE.__name__].add(hook_cls)
|
||||||
def add_hook(self, hook_cls):
|
for key in hook_cls.HOOK_GROUPS:
|
||||||
hook_groups = list(hook_cls.HOOK_GROUPS)
|
self.group_name_map.setdefault(key, set()).add(hook_cls)
|
||||||
assert self.DIRECTIVES.intersection(hook_groups)
|
|
||||||
hook_groups.append('all')
|
|
||||||
for name_attr in ['HOOK_NAME', 'METADATA_KEY', '__name__']:
|
|
||||||
try:
|
|
||||||
hook_name = getattr(hook_cls, name_attr)
|
|
||||||
except AttributeError:
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
hook_groups.append(hook_name)
|
|
||||||
break
|
|
||||||
for key in hook_groups:
|
|
||||||
self.group_hooks_map.setdefault(key, set()).add(hook_cls)
|
|
||||||
return hook_cls # to allow use as a decorator
|
return hook_cls # to allow use as a decorator
|
||||||
|
|
||||||
def import_hooks(self, mod_name, *hook_names, package=__module__):
|
def import_hooks(self, mod_name, *hook_names, package=__module__):
|
||||||
|
@ -50,13 +59,13 @@ class HookRegistry:
|
||||||
for hook_name in hook_names:
|
for hook_name in hook_names:
|
||||||
self.add_hook(getattr(module, hook_name))
|
self.add_hook(getattr(module, hook_name))
|
||||||
|
|
||||||
def group_by_directive(self, config_str=''):
|
def group_by_directive(self, config_str: str='') -> Mapping[HookName, List[Hook]]:
|
||||||
config_str = config_str.strip()
|
config_str = config_str.strip()
|
||||||
if not config_str:
|
if not config_str:
|
||||||
config_str = 'all'
|
config_str = 'all'
|
||||||
elif config_str.startswith('-'):
|
elif config_str.startswith('-'):
|
||||||
config_str = 'all ' + config_str
|
config_str = 'all ' + config_str
|
||||||
available_hooks = set()
|
available_hooks: Set[Type[Hook]] = set()
|
||||||
for token in config_str.split():
|
for token in config_str.split():
|
||||||
if token.startswith('-'):
|
if token.startswith('-'):
|
||||||
update_available = available_hooks.difference_update
|
update_available = available_hooks.difference_update
|
||||||
|
@ -65,29 +74,32 @@ class HookRegistry:
|
||||||
update_available = available_hooks.update
|
update_available = available_hooks.update
|
||||||
key = token
|
key = token
|
||||||
try:
|
try:
|
||||||
update_set = self.group_hooks_map[key]
|
update_set = self.group_name_map[key]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
raise ValueError("configuration refers to unknown hooks {!r}".format(key)) from None
|
raise ValueError("configuration refers to unknown hooks {!r}".format(key)) from None
|
||||||
else:
|
else:
|
||||||
update_available(update_set)
|
update_available(update_set)
|
||||||
return {key: [hook() for hook in self.group_hooks_map[key] & available_hooks]
|
return {
|
||||||
for key in self.DIRECTIVES}
|
t.__name__: [hook() for hook in self.group_name_map[t.__name__] & available_hooks]
|
||||||
|
for t in ALL_DIRECTIVES
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
HOOK_REGISTRY = HookRegistry()
|
HOOK_REGISTRY = HookRegistry()
|
||||||
HOOK_REGISTRY.import_hooks('.meta_expense_allocation', 'MetaExpenseAllocation')
|
HOOK_REGISTRY.import_hooks('.meta_expense_allocation', 'MetaExpenseAllocation')
|
||||||
HOOK_REGISTRY.import_hooks('.meta_tax_implication', 'MetaTaxImplication')
|
HOOK_REGISTRY.import_hooks('.meta_tax_implication', 'MetaTaxImplication')
|
||||||
|
|
||||||
def run(entries, options_map, config='', hook_registry=HOOK_REGISTRY):
|
def run(
|
||||||
errors = []
|
entries: List[Directive],
|
||||||
|
options_map: Dict[str, Any],
|
||||||
|
config: str='',
|
||||||
|
hook_registry: HookRegistry=HOOK_REGISTRY,
|
||||||
|
) -> Tuple[List[Directive], List[Error]]:
|
||||||
|
errors: List[Error] = []
|
||||||
hooks = hook_registry.group_by_directive(config)
|
hooks = hook_registry.group_by_directive(config)
|
||||||
for entry in entries:
|
for entry in entries:
|
||||||
entry_type = type(entry).__name__
|
entry_type = type(entry).__name__
|
||||||
for hook in hooks[entry_type]:
|
for hook in hooks[entry_type]:
|
||||||
errors.extend(hook.run(entry))
|
errors.extend(hook.run(entry))
|
||||||
if entry_type == 'Transaction':
|
|
||||||
for index, post in enumerate(entry.postings):
|
|
||||||
for hook in hooks['Posting']:
|
|
||||||
errors.extend(hook.run(entry, post, index))
|
|
||||||
return entries, errors
|
return entries, errors
|
||||||
|
|
||||||
|
|
|
@ -14,36 +14,37 @@
|
||||||
# You should have received a copy of the GNU Affero General Public License
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
import abc
|
||||||
import datetime
|
import datetime
|
||||||
import re
|
import re
|
||||||
|
|
||||||
from . import errors as errormod
|
from . import errors as errormod
|
||||||
|
|
||||||
from typing import (
|
from typing import (
|
||||||
AbstractSet,
|
FrozenSet,
|
||||||
Any,
|
|
||||||
ClassVar,
|
|
||||||
Generic,
|
Generic,
|
||||||
Iterable,
|
Iterable,
|
||||||
Iterator,
|
Iterator,
|
||||||
List,
|
|
||||||
Mapping,
|
Mapping,
|
||||||
Optional,
|
Optional,
|
||||||
Tuple,
|
|
||||||
TypeVar,
|
TypeVar,
|
||||||
Union,
|
|
||||||
)
|
)
|
||||||
from .._typing import (
|
from .._typing import (
|
||||||
Account,
|
Account,
|
||||||
HookName,
|
Directive,
|
||||||
|
Error,
|
||||||
|
ErrorIter,
|
||||||
LessComparable,
|
LessComparable,
|
||||||
MetaKey,
|
MetaKey,
|
||||||
MetaValue,
|
MetaValue,
|
||||||
MetaValueEnum,
|
MetaValueEnum,
|
||||||
Posting,
|
Posting,
|
||||||
Transaction,
|
Transaction,
|
||||||
|
Type,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
### CONSTANTS
|
||||||
|
|
||||||
# I expect these will become configurable in the future, which is why I'm
|
# I expect these will become configurable in the future, which is why I'm
|
||||||
# keeping them outside of a class, but for now constants will do.
|
# keeping them outside of a class, but for now constants will do.
|
||||||
DEFAULT_START_DATE: datetime.date = datetime.date(2020, 3, 1)
|
DEFAULT_START_DATE: datetime.date = datetime.date(2020, 3, 1)
|
||||||
|
@ -51,8 +52,27 @@ DEFAULT_START_DATE: datetime.date = datetime.date(2020, 3, 1)
|
||||||
# dates past the far end of the range.
|
# dates past the far end of the range.
|
||||||
DEFAULT_STOP_DATE: datetime.date = datetime.date(datetime.MAXYEAR, 1, 1)
|
DEFAULT_STOP_DATE: datetime.date = datetime.date(datetime.MAXYEAR, 1, 1)
|
||||||
|
|
||||||
CT = TypeVar('CT', bound=LessComparable)
|
### TYPE DEFINITIONS
|
||||||
|
|
||||||
|
HookName = str
|
||||||
|
|
||||||
|
Entry = TypeVar('Entry', bound=Directive)
|
||||||
|
class Hook(Generic[Entry], metaclass=abc.ABCMeta):
|
||||||
|
DIRECTIVE: Type[Directive]
|
||||||
|
HOOK_GROUPS: FrozenSet[HookName] = frozenset()
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def run(self, entry: Entry) -> ErrorIter: ...
|
||||||
|
|
||||||
|
def __init_subclass__(cls):
|
||||||
|
cls.DIRECTIVE = cls.__orig_bases__[0].__args__[0]
|
||||||
|
|
||||||
|
|
||||||
|
TransactionHook = Hook[Transaction]
|
||||||
|
|
||||||
|
### HELPER CLASSES
|
||||||
|
|
||||||
|
CT = TypeVar('CT', bound=LessComparable)
|
||||||
class _GenericRange(Generic[CT]):
|
class _GenericRange(Generic[CT]):
|
||||||
"""Convenience class to check whether a value is within a range.
|
"""Convenience class to check whether a value is within a range.
|
||||||
|
|
||||||
|
@ -143,24 +163,14 @@ class MetadataEnum:
|
||||||
return self[default_key]
|
return self[default_key]
|
||||||
|
|
||||||
|
|
||||||
class PostingChecker:
|
### HOOK SUBCLASSES
|
||||||
"""Base class to normalize posting metadata from an enum."""
|
|
||||||
# This class provides basic functionality to filter postings, normalize
|
|
||||||
# metadata values, and set default values.
|
|
||||||
# Subclasses should set:
|
|
||||||
# * METADATA_KEY: A string with the name of the metadata key to normalize.
|
|
||||||
# * ACCOUNTS: Only check postings that match these account names.
|
|
||||||
# Can be a tuple of account prefix strings, or a regexp.
|
|
||||||
# * VALUES_ENUM: A MetadataEnum with allowed values and aliases.
|
|
||||||
# Subclasses may wish to override _default_value and _should_check.
|
|
||||||
# See below.
|
|
||||||
|
|
||||||
METADATA_KEY: ClassVar[MetaKey]
|
class _PostingHook(TransactionHook, metaclass=abc.ABCMeta):
|
||||||
VALUES_ENUM: MetadataEnum
|
|
||||||
HOOK_GROUPS: AbstractSet[HookName] = frozenset(['Posting', 'metadata'])
|
|
||||||
ACCOUNTS: Union[str, Tuple[Account, ...]] = ('',)
|
|
||||||
TXN_DATE_RANGE: _GenericRange = _GenericRange(DEFAULT_START_DATE, DEFAULT_STOP_DATE)
|
TXN_DATE_RANGE: _GenericRange = _GenericRange(DEFAULT_START_DATE, DEFAULT_STOP_DATE)
|
||||||
|
|
||||||
|
def __init_subclass__(cls) -> None:
|
||||||
|
cls.HOOK_GROUPS = cls.HOOK_GROUPS.union(['posting'])
|
||||||
|
|
||||||
def _meta_get(self,
|
def _meta_get(self,
|
||||||
txn: Transaction,
|
txn: Transaction,
|
||||||
post: Posting,
|
post: Posting,
|
||||||
|
@ -184,6 +194,34 @@ class PostingChecker:
|
||||||
else:
|
else:
|
||||||
post.meta[key] = value
|
post.meta[key] = value
|
||||||
|
|
||||||
|
def _run_on_txn(self, txn: Transaction) -> bool:
|
||||||
|
return txn.date in self.TXN_DATE_RANGE
|
||||||
|
|
||||||
|
def _run_on_post(self, txn: Transaction, post: Posting) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
def run(self, txn: Transaction) -> ErrorIter:
|
||||||
|
if self._run_on_txn(txn):
|
||||||
|
for index, post in enumerate(txn.postings):
|
||||||
|
if self._run_on_post(txn, post):
|
||||||
|
yield from self.post_run(txn, post, index)
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def post_run(self, txn: Transaction, post: Posting, post_index: int) -> ErrorIter: ...
|
||||||
|
|
||||||
|
|
||||||
|
class _NormalizePostingMetadataHook(_PostingHook):
|
||||||
|
"""Base class to normalize posting metadata from an enum."""
|
||||||
|
# This class provides basic functionality to filter postings, normalize
|
||||||
|
# metadata values, and set default values.
|
||||||
|
METADATA_KEY: MetaKey
|
||||||
|
VALUES_ENUM: MetadataEnum
|
||||||
|
|
||||||
|
def __init_subclass__(cls) -> None:
|
||||||
|
super().__init_subclass__()
|
||||||
|
cls.METADATA_KEY = cls.VALUES_ENUM.key
|
||||||
|
cls.HOOK_GROUPS = cls.HOOK_GROUPS.union(['metadata', cls.METADATA_KEY])
|
||||||
|
|
||||||
# If the posting does not specify METADATA_KEY, the hook calls
|
# If the posting does not specify METADATA_KEY, the hook calls
|
||||||
# _default_value to get a default. This method should either return
|
# _default_value to get a default. This method should either return
|
||||||
# a value string from METADATA_ENUM, or else raise InvalidMetadataError.
|
# a value string from METADATA_ENUM, or else raise InvalidMetadataError.
|
||||||
|
@ -191,35 +229,23 @@ class PostingChecker:
|
||||||
def _default_value(self, txn: Transaction, post: Posting) -> MetaValueEnum:
|
def _default_value(self, txn: Transaction, post: Posting) -> MetaValueEnum:
|
||||||
raise errormod.InvalidMetadataError(txn, post, self.METADATA_KEY)
|
raise errormod.InvalidMetadataError(txn, post, self.METADATA_KEY)
|
||||||
|
|
||||||
# The hook calls _should_check on every posting and only checks postings
|
def post_run(self, txn: Transaction, post: Posting, post_index: int) -> ErrorIter:
|
||||||
# when the method returns true. This base method checks the transaction
|
|
||||||
# date is in TXN_DATE_RANGE, and the posting account name matches ACCOUNTS.
|
|
||||||
def _should_check(self, txn: Transaction, post: Posting) -> bool:
|
|
||||||
ok = txn.date in self.TXN_DATE_RANGE
|
|
||||||
if isinstance(self.ACCOUNTS, tuple):
|
|
||||||
ok = ok and post.account.startswith(self.ACCOUNTS)
|
|
||||||
else:
|
|
||||||
ok = ok and bool(re.search(self.ACCOUNTS, post.account))
|
|
||||||
return ok
|
|
||||||
|
|
||||||
def run(self, txn: Transaction, post: Posting, post_index: int) -> Iterable[errormod._BaseError]:
|
|
||||||
errors: List[errormod._BaseError] = []
|
|
||||||
if not self._should_check(txn, post):
|
|
||||||
return errors
|
|
||||||
source_value = self._meta_get(txn, post, self.METADATA_KEY)
|
source_value = self._meta_get(txn, post, self.METADATA_KEY)
|
||||||
set_value = source_value
|
set_value = source_value
|
||||||
|
error: Optional[Error] = None
|
||||||
if source_value is None:
|
if source_value is None:
|
||||||
try:
|
try:
|
||||||
set_value = self._default_value(txn, post)
|
set_value = self._default_value(txn, post)
|
||||||
except errormod._BaseError as error:
|
except errormod._BaseError as error_:
|
||||||
errors.append(error)
|
error = error_
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
set_value = self.VALUES_ENUM[source_value]
|
set_value = self.VALUES_ENUM[source_value]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
errors.append(errormod.InvalidMetadataError(
|
error = errormod.InvalidMetadataError(
|
||||||
txn, post, self.METADATA_KEY, source_value,
|
txn, post, self.METADATA_KEY, source_value,
|
||||||
))
|
)
|
||||||
if not errors:
|
if error is None:
|
||||||
self._meta_set(txn, post, post_index, self.METADATA_KEY, set_value)
|
self._meta_set(txn, post, post_index, self.METADATA_KEY, set_value)
|
||||||
return errors
|
else:
|
||||||
|
yield error
|
||||||
|
|
|
@ -15,11 +15,14 @@
|
||||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
from . import core
|
from . import core
|
||||||
|
from .._typing import (
|
||||||
|
MetaValueEnum,
|
||||||
|
Posting,
|
||||||
|
Transaction,
|
||||||
|
)
|
||||||
|
|
||||||
class MetaExpenseAllocation(core.PostingChecker):
|
class MetaExpenseAllocation(core._NormalizePostingMetadataHook):
|
||||||
ACCOUNTS = ('Expenses:',)
|
VALUES_ENUM = core.MetadataEnum('expense-allocation', {
|
||||||
METADATA_KEY = 'expense-allocation'
|
|
||||||
VALUES_ENUM = core.MetadataEnum(METADATA_KEY, {
|
|
||||||
'administration',
|
'administration',
|
||||||
'fundraising',
|
'fundraising',
|
||||||
'program',
|
'program',
|
||||||
|
@ -32,5 +35,8 @@ class MetaExpenseAllocation(core.PostingChecker):
|
||||||
'Expenses:Services:Fundraising': VALUES_ENUM['fundraising'],
|
'Expenses:Services:Fundraising': VALUES_ENUM['fundraising'],
|
||||||
}
|
}
|
||||||
|
|
||||||
def _default_value(self, txn, post):
|
def _run_on_post(self, txn: Transaction, post: Posting) -> bool:
|
||||||
|
return post.account.startswith('Expenses:')
|
||||||
|
|
||||||
|
def _default_value(self, txn: Transaction, post: Posting) -> MetaValueEnum:
|
||||||
return self.DEFAULT_VALUES.get(post.account, 'program')
|
return self.DEFAULT_VALUES.get(post.account, 'program')
|
||||||
|
|
|
@ -17,13 +17,15 @@
|
||||||
import decimal
|
import decimal
|
||||||
|
|
||||||
from . import core
|
from . import core
|
||||||
|
from .._typing import (
|
||||||
|
Posting,
|
||||||
|
Transaction,
|
||||||
|
)
|
||||||
|
|
||||||
DEFAULT_STOP_AMOUNT = decimal.Decimal(0)
|
DEFAULT_STOP_AMOUNT = decimal.Decimal(0)
|
||||||
|
|
||||||
class MetaTaxImplication(core.PostingChecker):
|
class MetaTaxImplication(core._NormalizePostingMetadataHook):
|
||||||
ACCOUNTS = ('Assets:',)
|
VALUES_ENUM = core.MetadataEnum('tax-implication', [
|
||||||
METADATA_KEY = 'tax-implication'
|
|
||||||
VALUES_ENUM = core.MetadataEnum(METADATA_KEY, [
|
|
||||||
'1099',
|
'1099',
|
||||||
'Accountant-Advises-No-1099',
|
'Accountant-Advises-No-1099',
|
||||||
'Bank-Transfer',
|
'Bank-Transfer',
|
||||||
|
@ -43,8 +45,9 @@ class MetaTaxImplication(core.PostingChecker):
|
||||||
'W2',
|
'W2',
|
||||||
], {})
|
], {})
|
||||||
|
|
||||||
def _should_check(self, txn, post):
|
def _run_on_post(self, txn: Transaction, post: Posting) -> bool:
|
||||||
return (
|
return bool(
|
||||||
super()._should_check(txn, post)
|
post.account.startswith('Assets:')
|
||||||
|
and post.units.number
|
||||||
and post.units.number < DEFAULT_STOP_AMOUNT
|
and post.units.number < DEFAULT_STOP_AMOUNT
|
||||||
)
|
)
|
||||||
|
|
|
@ -44,7 +44,7 @@ def test_valid_values_on_postings(src_value, set_value):
|
||||||
('Expenses:General', 25, {TEST_KEY: src_value}),
|
('Expenses:General', 25, {TEST_KEY: src_value}),
|
||||||
])
|
])
|
||||||
checker = meta_expense_allocation.MetaExpenseAllocation()
|
checker = meta_expense_allocation.MetaExpenseAllocation()
|
||||||
errors = checker.run(txn, txn.postings[-1], -1)
|
errors = list(checker.run(txn))
|
||||||
assert not errors
|
assert not errors
|
||||||
assert txn.postings[-1].meta.get(TEST_KEY) == set_value
|
assert txn.postings[-1].meta.get(TEST_KEY) == set_value
|
||||||
|
|
||||||
|
@ -55,7 +55,7 @@ def test_invalid_values_on_postings(src_value):
|
||||||
('Expenses:General', 25, {TEST_KEY: src_value}),
|
('Expenses:General', 25, {TEST_KEY: src_value}),
|
||||||
])
|
])
|
||||||
checker = meta_expense_allocation.MetaExpenseAllocation()
|
checker = meta_expense_allocation.MetaExpenseAllocation()
|
||||||
errors = checker.run(txn, txn.postings[-1], -1)
|
errors = list(checker.run(txn))
|
||||||
assert errors
|
assert errors
|
||||||
|
|
||||||
@pytest.mark.parametrize('src_value,set_value', VALID_VALUES.items())
|
@pytest.mark.parametrize('src_value,set_value', VALID_VALUES.items())
|
||||||
|
@ -65,7 +65,7 @@ def test_valid_values_on_transactions(src_value, set_value):
|
||||||
('Expenses:General', 25),
|
('Expenses:General', 25),
|
||||||
])
|
])
|
||||||
checker = meta_expense_allocation.MetaExpenseAllocation()
|
checker = meta_expense_allocation.MetaExpenseAllocation()
|
||||||
errors = checker.run(txn, txn.postings[-1], -1)
|
errors = list(checker.run(txn))
|
||||||
assert not errors
|
assert not errors
|
||||||
assert txn.postings[-1].meta.get(TEST_KEY) == set_value
|
assert txn.postings[-1].meta.get(TEST_KEY) == set_value
|
||||||
|
|
||||||
|
@ -76,7 +76,7 @@ def test_invalid_values_on_transactions(src_value):
|
||||||
('Expenses:General', 25),
|
('Expenses:General', 25),
|
||||||
])
|
])
|
||||||
checker = meta_expense_allocation.MetaExpenseAllocation()
|
checker = meta_expense_allocation.MetaExpenseAllocation()
|
||||||
errors = checker.run(txn, txn.postings[-1], -1)
|
errors = list(checker.run(txn))
|
||||||
assert errors
|
assert errors
|
||||||
|
|
||||||
@pytest.mark.parametrize('account', [
|
@pytest.mark.parametrize('account', [
|
||||||
|
@ -92,7 +92,7 @@ def test_non_expense_accounts_skipped(account):
|
||||||
('Expenses:General', 25, {TEST_KEY: 'program'}),
|
('Expenses:General', 25, {TEST_KEY: 'program'}),
|
||||||
])
|
])
|
||||||
checker = meta_expense_allocation.MetaExpenseAllocation()
|
checker = meta_expense_allocation.MetaExpenseAllocation()
|
||||||
errors = checker.run(txn, txn.postings[0], 0)
|
errors = list(checker.run(txn))
|
||||||
assert not errors
|
assert not errors
|
||||||
|
|
||||||
@pytest.mark.parametrize('account,set_value', [
|
@pytest.mark.parametrize('account,set_value', [
|
||||||
|
@ -108,7 +108,7 @@ def test_default_values(account, set_value):
|
||||||
(account, 25),
|
(account, 25),
|
||||||
])
|
])
|
||||||
checker = meta_expense_allocation.MetaExpenseAllocation()
|
checker = meta_expense_allocation.MetaExpenseAllocation()
|
||||||
errors = checker.run(txn, txn.postings[-1], -1)
|
errors = list(checker.run(txn))
|
||||||
assert not errors
|
assert not errors
|
||||||
assert txn.postings[-1].meta[TEST_KEY] == set_value
|
assert txn.postings[-1].meta[TEST_KEY] == set_value
|
||||||
|
|
||||||
|
@ -125,7 +125,7 @@ def test_default_value_set_in_date_range(date, set_value):
|
||||||
('Expenses:General', 25),
|
('Expenses:General', 25),
|
||||||
])
|
])
|
||||||
checker = meta_expense_allocation.MetaExpenseAllocation()
|
checker = meta_expense_allocation.MetaExpenseAllocation()
|
||||||
errors = checker.run(txn, txn.postings[-1], -1)
|
errors = list(checker.run(txn))
|
||||||
assert not errors
|
assert not errors
|
||||||
got_value = (txn.postings[-1].meta or {}).get(TEST_KEY)
|
got_value = (txn.postings[-1].meta or {}).get(TEST_KEY)
|
||||||
assert bool(got_value) == bool(set_value)
|
assert bool(got_value) == bool(set_value)
|
||||||
|
|
|
@ -56,7 +56,7 @@ def test_valid_values_on_postings(src_value, set_value):
|
||||||
('Assets:Cash', -25, {TEST_KEY: src_value}),
|
('Assets:Cash', -25, {TEST_KEY: src_value}),
|
||||||
])
|
])
|
||||||
checker = meta_tax_implication.MetaTaxImplication()
|
checker = meta_tax_implication.MetaTaxImplication()
|
||||||
errors = checker.run(txn, txn.postings[-1], -1)
|
errors = list(checker.run(txn))
|
||||||
assert not errors
|
assert not errors
|
||||||
assert txn.postings[-1].meta.get(TEST_KEY) == set_value
|
assert txn.postings[-1].meta.get(TEST_KEY) == set_value
|
||||||
|
|
||||||
|
@ -67,7 +67,7 @@ def test_invalid_values_on_postings(src_value):
|
||||||
('Assets:Cash', -25, {TEST_KEY: src_value}),
|
('Assets:Cash', -25, {TEST_KEY: src_value}),
|
||||||
])
|
])
|
||||||
checker = meta_tax_implication.MetaTaxImplication()
|
checker = meta_tax_implication.MetaTaxImplication()
|
||||||
errors = checker.run(txn, txn.postings[-1], -1)
|
errors = list(checker.run(txn))
|
||||||
assert errors
|
assert errors
|
||||||
|
|
||||||
@pytest.mark.parametrize('src_value,set_value', VALID_VALUES.items())
|
@pytest.mark.parametrize('src_value,set_value', VALID_VALUES.items())
|
||||||
|
@ -77,7 +77,7 @@ def test_valid_values_on_transactions(src_value, set_value):
|
||||||
('Assets:Cash', -25),
|
('Assets:Cash', -25),
|
||||||
])
|
])
|
||||||
checker = meta_tax_implication.MetaTaxImplication()
|
checker = meta_tax_implication.MetaTaxImplication()
|
||||||
errors = checker.run(txn, txn.postings[-1], -1)
|
errors = list(checker.run(txn))
|
||||||
assert not errors
|
assert not errors
|
||||||
assert txn.postings[-1].meta.get(TEST_KEY) == set_value
|
assert txn.postings[-1].meta.get(TEST_KEY) == set_value
|
||||||
|
|
||||||
|
@ -88,7 +88,7 @@ def test_invalid_values_on_transactions(src_value):
|
||||||
('Assets:Cash', -25),
|
('Assets:Cash', -25),
|
||||||
])
|
])
|
||||||
checker = meta_tax_implication.MetaTaxImplication()
|
checker = meta_tax_implication.MetaTaxImplication()
|
||||||
errors = checker.run(txn, txn.postings[-1], -1)
|
errors = list(checker.run(txn))
|
||||||
assert errors
|
assert errors
|
||||||
|
|
||||||
@pytest.mark.parametrize('account', [
|
@pytest.mark.parametrize('account', [
|
||||||
|
@ -102,7 +102,7 @@ def test_non_asset_accounts_skipped(account):
|
||||||
('Assets:Cash', -25, {TEST_KEY: 'USA-Corporation'}),
|
('Assets:Cash', -25, {TEST_KEY: 'USA-Corporation'}),
|
||||||
])
|
])
|
||||||
checker = meta_tax_implication.MetaTaxImplication()
|
checker = meta_tax_implication.MetaTaxImplication()
|
||||||
errors = checker.run(txn, txn.postings[0], 0)
|
errors = list(checker.run(txn))
|
||||||
assert not errors
|
assert not errors
|
||||||
|
|
||||||
def test_asset_credits_skipped():
|
def test_asset_credits_skipped():
|
||||||
|
@ -111,7 +111,7 @@ def test_asset_credits_skipped():
|
||||||
('Assets:Cash', 25),
|
('Assets:Cash', 25),
|
||||||
])
|
])
|
||||||
checker = meta_tax_implication.MetaTaxImplication()
|
checker = meta_tax_implication.MetaTaxImplication()
|
||||||
errors = checker.run(txn, txn.postings[-1], -1)
|
errors = list(checker.run(txn))
|
||||||
assert not errors
|
assert not errors
|
||||||
assert not txn.postings[-1].meta
|
assert not txn.postings[-1].meta
|
||||||
|
|
||||||
|
@ -128,5 +128,5 @@ def test_default_value_set_in_date_range(date, need_value):
|
||||||
('Assets:Cash', -25),
|
('Assets:Cash', -25),
|
||||||
])
|
])
|
||||||
checker = meta_tax_implication.MetaTaxImplication()
|
checker = meta_tax_implication.MetaTaxImplication()
|
||||||
errors = checker.run(txn, txn.postings[-1], -1)
|
errors = list(checker.run(txn))
|
||||||
assert bool(errors) == bool(need_value)
|
assert bool(errors) == bool(need_value)
|
||||||
|
|
|
@ -25,28 +25,28 @@ def hook_names(hooks, key):
|
||||||
|
|
||||||
def test_default_registrations():
|
def test_default_registrations():
|
||||||
hooks = plugin.HOOK_REGISTRY.group_by_directive()
|
hooks = plugin.HOOK_REGISTRY.group_by_directive()
|
||||||
post_hook_names = hook_names(hooks, 'Posting')
|
txn_hook_names = hook_names(hooks, 'Transaction')
|
||||||
assert len(post_hook_names) >= 2
|
assert len(txn_hook_names) >= 2
|
||||||
assert 'MetaExpenseAllocation' in post_hook_names
|
assert 'MetaExpenseAllocation' in txn_hook_names
|
||||||
assert 'MetaTaxImplication' in post_hook_names
|
assert 'MetaTaxImplication' in txn_hook_names
|
||||||
|
|
||||||
def test_exclude_single():
|
def test_exclude_single():
|
||||||
hooks = plugin.HOOK_REGISTRY.group_by_directive('-expense-allocation')
|
hooks = plugin.HOOK_REGISTRY.group_by_directive('-expense-allocation')
|
||||||
post_hook_names = hook_names(hooks, 'Posting')
|
txn_hook_names = hook_names(hooks, 'Transaction')
|
||||||
assert post_hook_names
|
assert txn_hook_names
|
||||||
assert 'MetaExpenseAllocation' not in post_hook_names
|
assert 'MetaExpenseAllocation' not in txn_hook_names
|
||||||
|
|
||||||
def test_exclude_group_then_include_single():
|
def test_exclude_group_then_include_single():
|
||||||
hooks = plugin.HOOK_REGISTRY.group_by_directive('-metadata expense-allocation')
|
hooks = plugin.HOOK_REGISTRY.group_by_directive('-metadata expense-allocation')
|
||||||
post_hook_names = hook_names(hooks, 'Posting')
|
txn_hook_names = hook_names(hooks, 'Transaction')
|
||||||
assert 'MetaExpenseAllocation' in post_hook_names
|
assert 'MetaExpenseAllocation' in txn_hook_names
|
||||||
assert 'MetaTaxImplication' not in post_hook_names
|
assert 'MetaTaxImplication' not in txn_hook_names
|
||||||
|
|
||||||
def test_include_group_then_exclude_single():
|
def test_include_group_then_exclude_single():
|
||||||
hooks = plugin.HOOK_REGISTRY.group_by_directive('metadata -tax-implication')
|
hooks = plugin.HOOK_REGISTRY.group_by_directive('metadata -tax-implication')
|
||||||
post_hook_names = hook_names(hooks, 'Posting')
|
txn_hook_names = hook_names(hooks, 'Transaction')
|
||||||
assert 'MetaExpenseAllocation' in post_hook_names
|
assert 'MetaExpenseAllocation' in txn_hook_names
|
||||||
assert 'MetaTaxImplication' not in post_hook_names
|
assert 'MetaTaxImplication' not in txn_hook_names
|
||||||
|
|
||||||
def test_unknown_group_name():
|
def test_unknown_group_name():
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
|
|
|
@ -18,14 +18,15 @@ import pytest
|
||||||
|
|
||||||
from . import testutil
|
from . import testutil
|
||||||
|
|
||||||
from conservancy_beancount import plugin
|
from conservancy_beancount import plugin, _typing
|
||||||
|
|
||||||
CONFIG_MAP = {}
|
CONFIG_MAP = {}
|
||||||
HOOK_REGISTRY = plugin.HookRegistry()
|
HOOK_REGISTRY = plugin.HookRegistry()
|
||||||
|
|
||||||
@HOOK_REGISTRY.add_hook
|
@HOOK_REGISTRY.add_hook
|
||||||
class TransactionCounter:
|
class TransactionCounter:
|
||||||
HOOK_GROUPS = frozenset(['Transaction', 'counter'])
|
DIRECTIVE = _typing.Transaction
|
||||||
|
HOOK_GROUPS = frozenset()
|
||||||
|
|
||||||
def run(self, txn):
|
def run(self, txn):
|
||||||
return ['txn:{}'.format(id(txn))]
|
return ['txn:{}'.format(id(txn))]
|
||||||
|
@ -33,10 +34,11 @@ class TransactionCounter:
|
||||||
|
|
||||||
@HOOK_REGISTRY.add_hook
|
@HOOK_REGISTRY.add_hook
|
||||||
class PostingCounter(TransactionCounter):
|
class PostingCounter(TransactionCounter):
|
||||||
HOOK_GROUPS = frozenset(['Posting', 'counter'])
|
DIRECTIVE = _typing.Transaction
|
||||||
|
HOOK_GROUPS = frozenset(['posting'])
|
||||||
|
|
||||||
def run(self, txn, post, post_index):
|
def run(self, txn):
|
||||||
return ['post:{}'.format(id(post))]
|
return ['post:{}'.format(id(post)) for post in txn.postings]
|
||||||
|
|
||||||
|
|
||||||
def map_errors(errors):
|
def map_errors(errors):
|
||||||
|
@ -74,7 +76,7 @@ def test_with_posting_hooks_only():
|
||||||
('Liabilites:CreditCard', -10),
|
('Liabilites:CreditCard', -10),
|
||||||
]),
|
]),
|
||||||
]
|
]
|
||||||
out_entries, errors = plugin.run(in_entries, CONFIG_MAP, 'Posting', HOOK_REGISTRY)
|
out_entries, errors = plugin.run(in_entries, CONFIG_MAP, 'posting', HOOK_REGISTRY)
|
||||||
assert len(out_entries) == 2
|
assert len(out_entries) == 2
|
||||||
errmap = map_errors(errors)
|
errmap = map_errors(errors)
|
||||||
assert len(errmap.get('txn', '')) == 0
|
assert len(errmap.get('txn', '')) == 0
|
||||||
|
|
Loading…
Reference in a new issue