reports: Make RelatedPostings an immutable data structure.

This was an early mistake, it makes data consistency mistakes too
easy, and I only used it once so far in actual code. Going to fix
this now so I can more safely build on top of this data structure.
This commit is contained in:
Brett Smith 2020-05-30 17:31:21 -04:00
parent dd949a4866
commit b37d7a3024
4 changed files with 269 additions and 131 deletions

View file

@ -63,6 +63,7 @@ import collections
import datetime import datetime
import enum import enum
import logging import logging
import operator
import re import re
import sys import sys
@ -72,6 +73,8 @@ from typing import (
Dict, Dict,
Iterable, Iterable,
Iterator, Iterator,
FrozenSet,
List,
Mapping, Mapping,
NamedTuple, NamedTuple,
Optional, Optional,
@ -79,6 +82,7 @@ from typing import (
Set, Set,
TextIO, TextIO,
Tuple, Tuple,
Union,
) )
from ..beancount_types import ( from ..beancount_types import (
Error, Error,
@ -100,11 +104,15 @@ from .. import rtutil
PROGNAME = 'accrual-report' PROGNAME = 'accrual-report'
PostGroups = Mapping[Optional[MetaValue], core.RelatedPostings] PostGroups = Mapping[Optional[MetaValue], 'AccrualPostings']
RTObject = Mapping[str, str] RTObject = Mapping[str, str]
logger = logging.getLogger('conservancy_beancount.reports.accrual') logger = logging.getLogger('conservancy_beancount.reports.accrual')
class Sentinel:
pass
class Account(NamedTuple): class Account(NamedTuple):
name: str name: str
balance_paid: Callable[[core.Balance], bool] balance_paid: Callable[[core.Balance], bool]
@ -135,22 +143,95 @@ class AccrualAccount(enum.Enum):
} }
class AccrualPostings(core.RelatedPostings):
def _meta_getter(key: MetaKey) -> Callable[[data.Posting], MetaValue]: # type:ignore[misc]
def meta_getter(post: data.Posting) -> MetaValue:
return post.meta.get(key)
return meta_getter
_FIELDS: Dict[str, Callable[[data.Posting], MetaValue]] = {
'account': operator.attrgetter('account'),
'contract': _meta_getter('contract'),
'cost': operator.attrgetter('cost'),
'entity': _meta_getter('entity'),
'invoice': _meta_getter('invoice'),
'purchase_order': _meta_getter('purchase-order'),
}
_INVOICE_COUNTER: Dict[str, int] = collections.defaultdict(int)
INCONSISTENT = Sentinel()
__slots__ = (
'accrual_type',
'account',
'accounts',
'contract',
'contracts',
'cost',
'costs',
'entity',
'entitys',
'entities',
'invoice',
'invoices',
'purchase_order',
'purchase_orders',
)
def __init__(self,
source: Iterable[data.Posting]=(),
*,
_can_own: bool=False,
) -> None:
super().__init__(source, _can_own=_can_own)
# The following type declarations tell mypy about values set in the for
# loop that are important enough to be referenced directly elsewhere.
self.account: Union[data.Account, Sentinel]
self.entitys: FrozenSet[MetaValue]
self.invoice: Union[MetaValue, Sentinel]
for name, get_func in self._FIELDS.items():
values = frozenset(get_func(post) for post in self)
setattr(self, f'{name}s', values)
if len(values) == 1:
one_value = next(iter(values))
else:
one_value = self.INCONSISTENT
setattr(self, name, one_value)
# Correct spelling = bug prevention for future users of this class.
self.entities = self.entitys
if self.account is self.INCONSISTENT:
self.accrual_type: Optional[AccrualAccount] = None
else:
self.accrual_type = AccrualAccount.classify(self)
def report_inconsistencies(self) -> Iterable[Error]:
for field_name, get_func in self._FIELDS.items():
if getattr(self, field_name) is self.INCONSISTENT:
for post in self:
errmsg = 'inconsistent {} for invoice {}: {}'.format(
field_name.replace('_', '-'),
self.invoice or "<none>",
get_func(post),
)
yield Error(post.meta, errmsg, post.meta.txn)
class BaseReport: class BaseReport:
def __init__(self, out_file: TextIO) -> None: def __init__(self, out_file: TextIO) -> None:
self.out_file = out_file self.out_file = out_file
self.logger = logger.getChild(type(self).__name__) self.logger = logger.getChild(type(self).__name__)
def _since_last_nonzero(self, posts: core.RelatedPostings) -> core.RelatedPostings: def _since_last_nonzero(self, posts: AccrualPostings) -> AccrualPostings:
retval = core.RelatedPostings() for index, (post, balance) in enumerate(posts.iter_with_balance()):
for post in posts: if balance.is_zero():
if retval.balance().is_zero(): start_index = index
retval.clear() try:
retval.add(post) empty = start_index == index
return retval except NameError:
empty = True
return posts if empty else AccrualPostings(posts[start_index + 1:])
def _report(self, def _report(self,
invoice: str, invoice: str,
posts: core.RelatedPostings, posts: AccrualPostings,
index: int, index: int,
) -> Iterable[str]: ) -> Iterable[str]:
raise NotImplementedError("BaseReport._report") raise NotImplementedError("BaseReport._report")
@ -164,7 +245,7 @@ class BaseReport:
class BalanceReport(BaseReport): class BalanceReport(BaseReport):
def _report(self, def _report(self,
invoice: str, invoice: str,
posts: core.RelatedPostings, posts: AccrualPostings,
index: int, index: int,
) -> Iterable[str]: ) -> Iterable[str]:
posts = self._since_last_nonzero(posts) posts = self._since_last_nonzero(posts)
@ -182,7 +263,7 @@ class OutgoingReport(BaseReport):
self.rt_client = rt_client self.rt_client = rt_client
self.rt_wrapper = rtutil.RT(rt_client) self.rt_wrapper = rtutil.RT(rt_client)
def _primary_rt_id(self, posts: core.RelatedPostings) -> rtutil.TicketAttachmentIds: def _primary_rt_id(self, posts: AccrualPostings) -> rtutil.TicketAttachmentIds:
rt_ids = posts.all_meta_links('rt-id') rt_ids = posts.all_meta_links('rt-id')
rt_ids_count = len(rt_ids) rt_ids_count = len(rt_ids)
if rt_ids_count != 1: if rt_ids_count != 1:
@ -195,7 +276,7 @@ class OutgoingReport(BaseReport):
def _report(self, def _report(self,
invoice: str, invoice: str,
posts: core.RelatedPostings, posts: AccrualPostings,
index: int, index: int,
) -> Iterable[str]: ) -> Iterable[str]:
posts = self._since_last_nonzero(posts) posts = self._since_last_nonzero(posts)
@ -329,28 +410,6 @@ class SearchTerm(NamedTuple):
) )
return cls(key, pattern) return cls(key, pattern)
def _consistency_check_one_thing(
key: MetaValue,
related: core.RelatedPostings,
get_name: str,
get_func: Callable[[data.Posting], Any],
) -> Iterable[Error]:
values = {get_func(post) for post in related}
if len(values) != 1:
for post in related:
errmsg = f'inconsistent {get_name} for invoice {key}: {get_func(post)}'
yield Error(post.meta, errmsg, post.meta.txn)
def consistency_check(groups: PostGroups) -> Iterable[Error]:
errfmt = 'inconsistent {} for invoice {}: {{}}'
for key, related in groups.items():
yield from _consistency_check_one_thing(
key, related, 'cost', lambda post: post.cost,
)
for checked_meta in ['contract', 'entity', 'purchase-order']:
yield from _consistency_check_one_thing(
key, related, checked_meta, lambda post: post.meta.get(checked_meta),
)
def filter_search(postings: Iterable[data.Posting], def filter_search(postings: Iterable[data.Posting],
search_terms: Iterable[SearchTerm], search_terms: Iterable[SearchTerm],
@ -421,16 +480,16 @@ def main(arglist: Optional[Sequence[str]]=None,
} }
load_errors = [Error(source, "no books to load in configuration", None)] load_errors = [Error(source, "no books to load in configuration", None)]
postings = filter_search(data.Posting.from_entries(entries), args.search_terms) postings = filter_search(data.Posting.from_entries(entries), args.search_terms)
groups = core.RelatedPostings.group_by_meta(postings, 'invoice') groups: PostGroups = dict(AccrualPostings.group_by_meta(postings, 'invoice'))
groups = AccrualAccount.filter_paid_accruals(groups) or groups groups = AccrualAccount.filter_paid_accruals(groups) or groups
meta_errors = consistency_check(groups)
returncode = 0 returncode = 0
for error in load_errors: for error in load_errors:
bc_printer.print_error(error, file=stderr) bc_printer.print_error(error, file=stderr)
returncode |= ReturnFlag.LOAD_ERRORS returncode |= ReturnFlag.LOAD_ERRORS
for error in meta_errors: for related in groups.values():
bc_printer.print_error(error, file=stderr) for error in related.report_inconsistencies():
returncode |= ReturnFlag.CONSISTENCY_ERRORS bc_printer.print_error(error, file=stderr)
returncode |= ReturnFlag.CONSISTENCY_ERRORS
if args.report_type is None: if args.report_type is None:
args.report_type = ReportType.default_for(groups) args.report_type = ReportType.default_for(groups)
if not groups: if not groups:

View file

@ -37,6 +37,8 @@ from typing import (
Sequence, Sequence,
Set, Set,
Tuple, Tuple,
Type,
TypeVar,
Union, Union,
) )
from ..beancount_types import ( from ..beancount_types import (
@ -45,6 +47,7 @@ from ..beancount_types import (
) )
DecimalCompat = data.DecimalCompat DecimalCompat = data.DecimalCompat
RelatedType = TypeVar('RelatedType', bound='RelatedPostings')
class Balance(Mapping[str, data.Amount]): class Balance(Mapping[str, data.Amount]):
"""A collection of amounts mapped by currency """A collection of amounts mapped by currency
@ -162,15 +165,23 @@ class RelatedPostings(Sequence[data.Posting]):
""" """
__slots__ = ('_postings',) __slots__ = ('_postings',)
def __init__(self, source: Iterable[data.Posting]=()) -> None: def __init__(self,
self._postings: List[data.Posting] = list(source) source: Iterable[data.Posting]=(),
*,
_can_own: bool=False,
) -> None:
self._postings: List[data.Posting]
if _can_own and isinstance(source, list):
self._postings = source
else:
self._postings = list(source)
@classmethod @classmethod
def group_by_meta(cls, def group_by_meta(cls: Type[RelatedType],
postings: Iterable[data.Posting], postings: Iterable[data.Posting],
key: MetaKey, key: MetaKey,
default: Optional[MetaValue]=None, default: Optional[MetaValue]=None,
) -> Mapping[Optional[MetaValue], 'RelatedPostings']: ) -> Iterator[Tuple[Optional[MetaValue], RelatedType]]:
"""Relate postings by metadata value """Relate postings by metadata value
This method takes an iterable of postings and returns a mapping. This method takes an iterable of postings and returns a mapping.
@ -178,32 +189,29 @@ class RelatedPostings(Sequence[data.Posting]):
The values are RelatedPostings instances that contain all the postings The values are RelatedPostings instances that contain all the postings
that had that same metadata value. that had that same metadata value.
""" """
retval: DefaultDict[Optional[MetaValue], 'RelatedPostings'] = collections.defaultdict(cls) mapping: DefaultDict[Optional[MetaValue], List[data.Posting]] = collections.defaultdict(list)
for post in postings: for post in postings:
retval[post.meta.get(key, default)].add(post) mapping[post.meta.get(key, default)].append(post)
retval.default_factory = None for value, posts in mapping.items():
return retval yield value, cls(posts, _can_own=True)
@overload @overload
def __getitem__(self, index: int) -> data.Posting: ... def __getitem__(self: RelatedType, index: int) -> data.Posting: ...
@overload @overload
def __getitem__(self, s: slice) -> Sequence[data.Posting]: ... def __getitem__(self: RelatedType, s: slice) -> RelatedType: ...
def __getitem__(self, def __getitem__(self: RelatedType,
index: Union[int, slice], index: Union[int, slice],
) -> Union[data.Posting, Sequence[data.Posting]]: ) -> Union[data.Posting, RelatedType]:
if isinstance(index, slice): if isinstance(index, slice):
raise NotImplementedError("RelatedPostings[slice]") return type(self)(self._postings[index], _can_own=True)
else: else:
return self._postings[index] return self._postings[index]
def __len__(self) -> int: def __len__(self) -> int:
return len(self._postings) return len(self._postings)
def add(self, post: data.Posting) -> None:
self._postings.append(post)
def all_meta_links(self, key: MetaKey) -> Set[str]: def all_meta_links(self, key: MetaKey) -> Set[str]:
retval: Set[str] = set() retval: Set[str] = set()
for post in self: for post in self:
@ -213,9 +221,6 @@ class RelatedPostings(Sequence[data.Posting]):
pass pass
return retval return retval
def clear(self) -> None:
self._postings.clear()
def iter_with_balance(self) -> Iterator[Tuple[data.Posting, Balance]]: def iter_with_balance(self) -> Iterator[Tuple[data.Posting, Balance]]:
balance = MutableBalance() balance = MutableBalance()
for post in self: for post in self:

View file

@ -94,8 +94,8 @@ def check_link_regexp(regexp, match_s, first_link_only=False):
else: else:
assert end_match assert end_match
def relate_accruals_by_meta(postings, value, key='invoice'): def accruals_by_meta(postings, value, key='invoice', wrap_type=iter):
return core.RelatedPostings( return wrap_type(
post for post in postings post for post in postings
if post.meta.get(key) == value if post.meta.get(key) == value
and post.account.is_under('Assets:Receivable', 'Liabilities:Payable') and post.account.is_under('Assets:Receivable', 'Liabilities:Payable')
@ -200,22 +200,107 @@ def test_report_type_by_unknown_name(arg):
with pytest.raises(ValueError): with pytest.raises(ValueError):
accrual.ReportType.by_name(arg) accrual.ReportType.by_name(arg)
@pytest.mark.parametrize('acct_name', ACCOUNTS)
def test_accrual_postings_consistent_account(acct_name):
meta = {'invoice': '{acct_name} invoice.pdf'}
txn = testutil.Transaction(postings=[
(acct_name, 50, meta),
(acct_name, 25, meta),
])
related = accrual.AccrualPostings(data.Posting.from_txn(txn))
assert related.account == acct_name
assert related.accounts == {acct_name}
@pytest.mark.parametrize('cost', [
testutil.Cost('1.2', 'USD'),
None,
])
def test_accrual_postings_consistent_cost(cost):
meta = {'invoice': 'FXinvoice.pdf'}
txn = testutil.Transaction(postings=[
(ACCOUNTS[0], 60, 'EUR', cost, meta),
(ACCOUNTS[0], 30, 'EUR', cost, meta),
])
related = accrual.AccrualPostings(data.Posting.from_txn(txn))
assert related.cost == cost
assert related.costs == {cost}
@pytest.mark.parametrize('meta_key,acct_name', testutil.combine_values(
CONSISTENT_METADATA,
ACCOUNTS,
))
def test_accrual_postings_consistent_metadata(meta_key, acct_name):
meta_value = f'{meta_key}.pdf'
meta = {
meta_key: meta_value,
'invoice': f'invoice with {meta_key}.pdf',
}
txn = testutil.Transaction(postings=[
(acct_name, 70, meta),
(acct_name, 35, meta),
])
related = accrual.AccrualPostings(data.Posting.from_txn(txn))
attr_name = meta_key.replace('-', '_')
assert getattr(related, attr_name) == meta_value
assert getattr(related, f'{attr_name}s') == {meta_value}
def test_accrual_postings_inconsistent_account():
meta = {'invoice': 'invoice.pdf'}
txn = testutil.Transaction(postings=[
(acct_name, index, meta)
for index, acct_name in enumerate(ACCOUNTS)
])
related = accrual.AccrualPostings(data.Posting.from_txn(txn))
assert related.account is related.INCONSISTENT
assert related.accounts == set(ACCOUNTS)
def test_accrual_postings_inconsistent_cost():
meta = {'invoice': 'FXinvoice.pdf'}
costs = {
testutil.Cost('1.1', 'USD'),
testutil.Cost('1.2', 'USD'),
}
txn = testutil.Transaction(postings=[
(ACCOUNTS[0], 10, 'EUR', cost, meta)
for cost in costs
])
related = accrual.AccrualPostings(data.Posting.from_txn(txn))
assert related.cost is related.INCONSISTENT
assert related.costs == costs
@pytest.mark.parametrize('meta_key,acct_name', testutil.combine_values(
CONSISTENT_METADATA,
ACCOUNTS,
))
def test_accrual_postings_inconsistent_metadata(meta_key, acct_name):
invoice = 'invoice with {meta_key}.pdf'
meta_value = f'{meta_key}.pdf'
txn = testutil.Transaction(postings=[
(acct_name, 20, {'invoice': invoice, meta_key: meta_value}),
(acct_name, 35, {'invoice': invoice}),
])
related = accrual.AccrualPostings(data.Posting.from_txn(txn))
attr_name = meta_key.replace('-', '_')
assert getattr(related, attr_name) is related.INCONSISTENT
assert getattr(related, f'{attr_name}s') == {meta_value, None}
@pytest.mark.parametrize('meta_key,account', testutil.combine_values( @pytest.mark.parametrize('meta_key,account', testutil.combine_values(
CONSISTENT_METADATA, CONSISTENT_METADATA,
ACCOUNTS, ACCOUNTS,
)) ))
def test_consistency_check_when_consistent(meta_key, account): def test_consistency_check_when_consistent(meta_key, account):
invoice = f'test-{meta_key}-invoice' invoice = f'test-{meta_key}-invoice'
meta_value = f'test-{meta_key}-value'
meta = { meta = {
'invoice': invoice, 'invoice': invoice,
meta_key: f'test-{meta_key}-value', meta_key: meta_value,
} }
txn = testutil.Transaction(postings=[ txn = testutil.Transaction(postings=[
(account, 100, meta), (account, 100, meta),
(account, -100, meta), (account, -100, meta),
]) ])
related = core.RelatedPostings(data.Posting.from_txn(txn)) related = accrual.AccrualPostings(data.Posting.from_txn(txn))
assert not list(accrual.consistency_check({invoice: related})) assert not list(related.report_inconsistencies())
@pytest.mark.parametrize('meta_key,account', testutil.combine_values( @pytest.mark.parametrize('meta_key,account', testutil.combine_values(
['approval', 'fx-rate', 'statement'], ['approval', 'fx-rate', 'statement'],
@ -227,8 +312,8 @@ def test_consistency_check_ignored_metadata(meta_key, account):
(account, 100, {'invoice': invoice, meta_key: 'credit'}), (account, 100, {'invoice': invoice, meta_key: 'credit'}),
(account, -100, {'invoice': invoice, meta_key: 'debit'}), (account, -100, {'invoice': invoice, meta_key: 'debit'}),
]) ])
related = core.RelatedPostings(data.Posting.from_txn(txn)) related = accrual.AccrualPostings(data.Posting.from_txn(txn))
assert not list(accrual.consistency_check({invoice: related})) assert not list(related.report_inconsistencies())
@pytest.mark.parametrize('meta_key,account', testutil.combine_values( @pytest.mark.parametrize('meta_key,account', testutil.combine_values(
CONSISTENT_METADATA, CONSISTENT_METADATA,
@ -240,8 +325,8 @@ def test_consistency_check_when_inconsistent(meta_key, account):
(account, 100, {'invoice': invoice, meta_key: 'credit', 'lineno': 1}), (account, 100, {'invoice': invoice, meta_key: 'credit', 'lineno': 1}),
(account, -100, {'invoice': invoice, meta_key: 'debit', 'lineno': 2}), (account, -100, {'invoice': invoice, meta_key: 'debit', 'lineno': 2}),
]) ])
related = core.RelatedPostings(data.Posting.from_txn(txn)) related = accrual.AccrualPostings(data.Posting.from_txn(txn))
errors = list(accrual.consistency_check({invoice: related})) errors = list(related.report_inconsistencies())
for exp_lineno, (actual, exp_msg) in enumerate(itertools.zip_longest(errors, [ for exp_lineno, (actual, exp_msg) in enumerate(itertools.zip_longest(errors, [
f'inconsistent {meta_key} for invoice {invoice}: credit', f'inconsistent {meta_key} for invoice {invoice}: credit',
f'inconsistent {meta_key} for invoice {invoice}: debit', f'inconsistent {meta_key} for invoice {invoice}: debit',
@ -257,8 +342,8 @@ def test_consistency_check_cost():
(account, 100, 'EUR', ('1.1251', 'USD'), {'invoice': invoice, 'lineno': 1}), (account, 100, 'EUR', ('1.1251', 'USD'), {'invoice': invoice, 'lineno': 1}),
(account, -100, 'EUR', ('1.125', 'USD'), {'invoice': invoice, 'lineno': 2}), (account, -100, 'EUR', ('1.125', 'USD'), {'invoice': invoice, 'lineno': 2}),
]) ])
related = core.RelatedPostings(data.Posting.from_txn(txn)) related = accrual.AccrualPostings(data.Posting.from_txn(txn))
errors = list(accrual.consistency_check({invoice: related})) errors = list(related.report_inconsistencies())
for post, err in itertools.zip_longest(txn.postings, errors): for post, err in itertools.zip_longest(txn.postings, errors):
assert err.message == f'inconsistent cost for invoice {invoice}: {post.cost}' assert err.message == f'inconsistent cost for invoice {invoice}: {post.cost}'
assert err.entry is txn assert err.entry is txn
@ -272,7 +357,7 @@ def run_outgoing(invoice, postings, rt_client=None):
if rt_client is None: if rt_client is None:
rt_client = RTClient() rt_client = RTClient()
if not isinstance(postings, core.RelatedPostings): if not isinstance(postings, core.RelatedPostings):
postings = relate_accruals_by_meta(postings, invoice) postings = accruals_by_meta(postings, invoice, wrap_type=accrual.AccrualPostings)
output = io.StringIO() output = io.StringIO()
report = accrual.OutgoingReport(rt_client, output) report = accrual.OutgoingReport(rt_client, output)
report.run({invoice: postings}) report.run({invoice: postings})
@ -285,7 +370,7 @@ def run_outgoing(invoice, postings, rt_client=None):
('rt://ticket/515/attachments/5150', "1,500.00 USD outstanding since 2020-05-15",), ('rt://ticket/515/attachments/5150', "1,500.00 USD outstanding since 2020-05-15",),
]) ])
def test_balance_report(accrual_postings, invoice, expected, caplog): def test_balance_report(accrual_postings, invoice, expected, caplog):
related = relate_accruals_by_meta(accrual_postings, invoice) related = accruals_by_meta(accrual_postings, invoice, wrap_type=accrual.AccrualPostings)
output = io.StringIO() output = io.StringIO()
report = accrual.BalanceReport(output) report = accrual.BalanceReport(output)
report.run({invoice: related}) report.run({invoice: related})

View file

@ -80,42 +80,27 @@ def test_balance_empty():
assert not balance assert not balance
assert balance.is_zero() assert balance.is_zero()
def test_balance_credit_card(credit_card_cycle): @pytest.mark.parametrize('index,expected', enumerate([
related = core.RelatedPostings() -110,
assert related.balance() == testutil.balance_map() 0,
expected = Decimal() -120,
for txn in credit_card_cycle: 0,
post = txn.postings[0] ]))
expected += post.units.number def test_balance_credit_card(credit_card_cycle, index, expected):
related.add(post) related = core.RelatedPostings(
assert related.balance() == testutil.balance_map(USD=expected) txn.postings[0] for txn in credit_card_cycle[:index + 1]
assert expected == 0 )
assert related.balance() == testutil.balance_map(USD=expected)
def test_clear_after_add():
related = core.RelatedPostings()
related.add(testutil.Posting('Income:Donations', -10))
assert related.balance()
related.clear()
assert not related.balance()
def test_clear_after_initialization():
related = core.RelatedPostings([
testutil.Posting('Income:Donations', -12),
])
assert related.balance()
related.clear()
assert not related.balance()
def check_iter_with_balance(entries): def check_iter_with_balance(entries):
expect_posts = [txn.postings[0] for txn in entries] expect_posts = [txn.postings[0] for txn in entries]
expect_balances = [] expect_balances = []
balance_tally = collections.defaultdict(Decimal) balance_tally = collections.defaultdict(Decimal)
related = core.RelatedPostings()
for post in expect_posts: for post in expect_posts:
number, currency = post.units number, currency = post.units
balance_tally[currency] += number balance_tally[currency] += number
expect_balances.append(testutil.balance_map(balance_tally.items())) expect_balances.append(testutil.balance_map(balance_tally.items()))
related.add(post) related = core.RelatedPostings(expect_posts)
for (post, balance), exp_post, exp_balance in zip( for (post, balance), exp_post, exp_balance in zip(
related.iter_with_balance(), related.iter_with_balance(),
expect_posts, expect_posts,
@ -195,48 +180,56 @@ def test_meta_values_empty():
assert related.meta_values('key') == set() assert related.meta_values('key') == set()
def test_meta_values_no_match(): def test_meta_values_no_match():
related = core.RelatedPostings() related = core.RelatedPostings([
related.add(testutil.Posting('Income:Donations', -1, metakey='metavalue')) testutil.Posting('Income:Donations', -1, metakey='metavalue'),
])
assert related.meta_values('key') == {None} assert related.meta_values('key') == {None}
def test_meta_values_no_match_default_given(): def test_meta_values_no_match_default_given():
related = core.RelatedPostings() related = core.RelatedPostings([
related.add(testutil.Posting('Income:Donations', -1, metakey='metavalue')) testutil.Posting('Income:Donations', -1, metakey='metavalue'),
])
assert related.meta_values('key', '') == {''} assert related.meta_values('key', '') == {''}
def test_meta_values_one_match(): def test_meta_values_one_match():
related = core.RelatedPostings() related = core.RelatedPostings([
related.add(testutil.Posting('Income:Donations', -1, key='metavalue')) testutil.Posting('Income:Donations', -1, key='metavalue'),
])
assert related.meta_values('key') == {'metavalue'} assert related.meta_values('key') == {'metavalue'}
def test_meta_values_some_match(): def test_meta_values_some_match():
related = core.RelatedPostings() related = core.RelatedPostings([
related.add(testutil.Posting('Income:Donations', -1, key='1')) testutil.Posting('Income:Donations', -1, key='1'),
related.add(testutil.Posting('Income:Donations', -2, metakey='2')) testutil.Posting('Income:Donations', -2, metakey='2'),
])
assert related.meta_values('key') == {'1', None} assert related.meta_values('key') == {'1', None}
def test_meta_values_some_match_default_given(): def test_meta_values_some_match_default_given():
related = core.RelatedPostings() related = core.RelatedPostings([
related.add(testutil.Posting('Income:Donations', -1, key='1')) testutil.Posting('Income:Donations', -1, key='1'),
related.add(testutil.Posting('Income:Donations', -2, metakey='2')) testutil.Posting('Income:Donations', -2, metakey='2'),
])
assert related.meta_values('key', '') == {'1', ''} assert related.meta_values('key', '') == {'1', ''}
def test_meta_values_all_match(): def test_meta_values_all_match():
related = core.RelatedPostings() related = core.RelatedPostings([
related.add(testutil.Posting('Income:Donations', -1, key='1')) testutil.Posting('Income:Donations', -1, key='1'),
related.add(testutil.Posting('Income:Donations', -2, key='2')) testutil.Posting('Income:Donations', -2, key='2'),
])
assert related.meta_values('key') == {'1', '2'} assert related.meta_values('key') == {'1', '2'}
def test_meta_values_all_match_one_value(): def test_meta_values_all_match_one_value():
related = core.RelatedPostings() related = core.RelatedPostings([
related.add(testutil.Posting('Income:Donations', -1, key='1')) testutil.Posting('Income:Donations', -1, key='1'),
related.add(testutil.Posting('Income:Donations', -2, key='1')) testutil.Posting('Income:Donations', -2, key='1'),
])
assert related.meta_values('key') == {'1'} assert related.meta_values('key') == {'1'}
def test_meta_values_all_match_default_given(): def test_meta_values_all_match_default_given():
related = core.RelatedPostings() related = core.RelatedPostings([
related.add(testutil.Posting('Income:Donations', -1, key='1')) testutil.Posting('Income:Donations', -1, key='1'),
related.add(testutil.Posting('Income:Donations', -2, key='2')) testutil.Posting('Income:Donations', -2, key='2'),
])
assert related.meta_values('key', '') == {'1', '2'} assert related.meta_values('key', '') == {'1', '2'}
def test_meta_values_many_types(): def test_meta_values_many_types():
@ -246,9 +239,10 @@ def test_meta_values_many_types():
testutil.Amount(5), testutil.Amount(5),
'rt:42', 'rt:42',
} }
related = core.RelatedPostings() related = core.RelatedPostings(
for index, value in enumerate(expected): testutil.Posting('Income:Donations', -index, key=value)
related.add(testutil.Posting('Income:Donations', -index, key=value)) for index, value in enumerate(expected)
)
assert related.meta_values('key') == expected assert related.meta_values('key') == expected
@pytest.mark.parametrize('count', range(3)) @pytest.mark.parametrize('count', range(3))
@ -289,23 +283,18 @@ def test_all_meta_links_multiples():
assert related.all_meta_links('approval') == testutil.LINK_METADATA_STRINGS assert related.all_meta_links('approval') == testutil.LINK_METADATA_STRINGS
def test_group_by_meta_zero(): def test_group_by_meta_zero():
assert len(core.RelatedPostings.group_by_meta([], 'metacurrency')) == 0 assert not list(core.RelatedPostings.group_by_meta([], 'metacurrency'))
def test_group_by_meta_key_error():
# Make sure the return value doesn't act like a defaultdict.
with pytest.raises(KeyError):
core.RelatedPostings.group_by_meta([], 'metakey')['metavalue']
def test_group_by_meta_one(credit_card_cycle): def test_group_by_meta_one(credit_card_cycle):
posting = next(post for post in data.Posting.from_entries(credit_card_cycle) posting = next(post for post in data.Posting.from_entries(credit_card_cycle)
if post.account.is_credit_card()) if post.account.is_credit_card())
actual = core.RelatedPostings.group_by_meta([posting], 'metacurrency') actual = core.RelatedPostings.group_by_meta([posting], 'metacurrency')
assert set(actual) == {'USD'} assert set(key for key, _ in actual) == {'USD'}
def test_group_by_meta_many(two_accruals_three_payments): def test_group_by_meta_many(two_accruals_three_payments):
postings = [post for post in data.Posting.from_entries(two_accruals_three_payments) postings = [post for post in data.Posting.from_entries(two_accruals_three_payments)
if post.account == 'Assets:Receivable:Accounts'] if post.account == 'Assets:Receivable:Accounts']
actual = core.RelatedPostings.group_by_meta(postings, 'metacurrency') actual = dict(core.RelatedPostings.group_by_meta(postings, 'metacurrency'))
assert set(actual) == {'USD', 'EUR'} assert set(actual) == {'USD', 'EUR'}
for key, group in actual.items(): for key, group in actual.items():
assert 2 <= len(group) <= 3 assert 2 <= len(group) <= 3
@ -314,6 +303,6 @@ def test_group_by_meta_many(two_accruals_three_payments):
def test_group_by_meta_many_single_posts(two_accruals_three_payments): def test_group_by_meta_many_single_posts(two_accruals_three_payments):
postings = [post for post in data.Posting.from_entries(two_accruals_three_payments) postings = [post for post in data.Posting.from_entries(two_accruals_three_payments)
if post.account == 'Assets:Receivable:Accounts'] if post.account == 'Assets:Receivable:Accounts']
actual = core.RelatedPostings.group_by_meta(postings, 'metanumber') actual = dict(core.RelatedPostings.group_by_meta(postings, 'metanumber'))
assert set(actual) == {post.units.number for post in postings} assert set(actual) == {post.units.number for post in postings}
assert len(actual) == len(postings) assert len(actual) == len(postings)