data.balance_of: Take account predicates, not just names.

For increased flexibility.
In particular, now you can pass in Account boolean methods to
call those directly.
This commit is contained in:
Brett Smith 2020-04-08 14:16:57 -04:00
parent 28e59e7a3b
commit bb84cb5741
2 changed files with 42 additions and 29 deletions

View file

@ -261,26 +261,23 @@ class Posting(BasePosting):
def balance_of(txn: Transaction, def balance_of(txn: Transaction,
*accounts: str, *preds: Callable[[Account], Optional[bool]],
default: Optional[DecimalCompat]=None, default: Optional[DecimalCompat]=None,
) -> Optional[decimal.Decimal]: ) -> Optional[decimal.Decimal]:
"""Return the balance of specified postings in a transaction. """Return the balance of specified postings in a transaction.
Given a transaction and a series of account names, balance_of returns the Given a transaction and a series of account predicates, balance_of
balance of the amounts of all postings under those account names. returns the balance of the amounts of all postings with accounts that
match any of the predicates.
Account names are matched using Account.is_under. Refer to that docstring
for details about what matches.
If any of the postings have no amount, returns default. If any of the postings have no amount, returns default.
""" """
if default is not None:
default = decimal.Decimal(default)
retval = decimal.Decimal(0) retval = decimal.Decimal(0)
for post in txn.postings: for post in txn.postings:
if Account(post.account).is_under(*accounts): acct = Account(post.account)
if any(p(acct) for p in preds):
if post.units.number is None: if post.units.number is None:
return default return None if default is None else decimal.Decimal(default)
else: else:
retval += post.units.number retval += post.units.number
return retval return retval

View file

@ -15,6 +15,7 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from decimal import Decimal from decimal import Decimal
from operator import methodcaller
import pytest import pytest
@ -22,6 +23,8 @@ from . import testutil
from conservancy_beancount import data from conservancy_beancount import data
is_cash_eq = data.Account.is_cash_equivalent
@pytest.fixture @pytest.fixture
def payable_payment_txn(): def payable_payment_txn():
return testutil.Transaction(postings=[ return testutil.Transaction(postings=[
@ -48,43 +51,56 @@ def multipost_one_none_txn():
('Assets:Checking', None), ('Assets:Checking', None),
]) ])
def balance_under(txn, *accts):
pred = methodcaller('is_under', *accts)
return data.balance_of(txn, pred)
def test_balance_of_simple_txn(): def test_balance_of_simple_txn():
txn = testutil.Transaction(postings=[ txn = testutil.Transaction(postings=[
('Assets:Cash', 50), ('Assets:Cash', 50),
('Income:Donations', -50), ('Income:Donations', -50),
]) ])
assert data.balance_of(txn, 'Assets') == 50 assert balance_under(txn, 'Assets') == 50
assert data.balance_of(txn, 'Income') == -50 assert balance_under(txn, 'Income') == -50
def test_zero_balance_of(payable_payment_txn): def test_zero_balance_of(payable_payment_txn):
assert data.balance_of(payable_payment_txn, 'Equity') == 0 assert balance_under(payable_payment_txn, 'Equity') == 0
assert data.balance_of(payable_payment_txn, 'Assets:Cash') == 0 assert balance_under(payable_payment_txn, 'Assets:Cash') == 0
assert data.balance_of(payable_payment_txn, 'Liabilities:CreditCard') == 0 assert balance_under(payable_payment_txn, 'Liabilities:CreditCard') == 0
def test_nonzero_balance_of(payable_payment_txn):
assert balance_under(payable_payment_txn, 'Assets', 'Expenses') == -50
assert balance_under(payable_payment_txn, 'Assets', 'Liabilities') == -5
def test_multiarg_balance_of():
txn = testutil.Transaction(postings=[
('Liabilities:CreditCard', 650),
('Expenses:BankingFees', 5),
('Assets:Checking', -655),
])
assert data.balance_of(txn, is_cash_eq, data.Account.is_credit_card) == -5
def test_balance_of_multipost_txn(payable_payment_txn): def test_balance_of_multipost_txn(payable_payment_txn):
assert data.balance_of(payable_payment_txn, 'Assets') == -55 assert data.balance_of(payable_payment_txn, is_cash_eq) == -55
def test_multiarg_balance_of(payable_payment_txn):
assert data.balance_of(payable_payment_txn, 'Assets', 'Expenses') == -50
assert data.balance_of(payable_payment_txn, 'Assets', 'Liabilities') == -5
def test_balance_of_uses_whole_account_names(payable_payment_txn):
assert data.balance_of(payable_payment_txn, 'Assets:Check') == 0
def test_balance_of_none_posting(none_posting_txn): def test_balance_of_none_posting(none_posting_txn):
assert data.balance_of(none_posting_txn, 'Assets') is None assert data.balance_of(none_posting_txn, is_cash_eq) is None
def test_balance_of_none_posting_with_default(none_posting_txn): def test_balance_of_none_posting_with_default(none_posting_txn):
expected = Decimal('Infinity') expected = Decimal('Infinity')
assert data.balance_of(none_posting_txn, 'Assets', default=expected) == expected assert expected == data.balance_of(
none_posting_txn, is_cash_eq, default=expected,
)
def test_balance_of_other_side_of_none_posting(none_posting_txn): def test_balance_of_other_side_of_none_posting(none_posting_txn):
assert data.balance_of(none_posting_txn, 'Income') == -30 assert balance_under(none_posting_txn, 'Income') == -30
assert data.balance_of(none_posting_txn, 'Expenses') == 3 assert balance_under(none_posting_txn, 'Expenses') == 3
def test_balance_of_multi_postings_one_none(multipost_one_none_txn): def test_balance_of_multi_postings_one_none(multipost_one_none_txn):
assert data.balance_of(multipost_one_none_txn, 'Assets') is None assert data.balance_of(multipost_one_none_txn, is_cash_eq) is None
def test_balance_of_multi_postings_one_none(multipost_one_none_txn): def test_balance_of_multi_postings_one_none(multipost_one_none_txn):
expected = Decimal('Infinity') expected = Decimal('Infinity')
assert data.balance_of(multipost_one_none_txn, 'Assets', default=expected) == expected assert expected == data.balance_of(
multipost_one_none_txn, is_cash_eq, default=expected,
)