diff --git a/conservancy_beancount/reports/core.py b/conservancy_beancount/reports/core.py index 7c6debc..e4f168c 100644 --- a/conservancy_beancount/reports/core.py +++ b/conservancy_beancount/reports/core.py @@ -23,11 +23,60 @@ from .. import data from typing import ( overload, Dict, + Iterable, + Iterator, List, + Mapping, + Optional, Sequence, + Tuple, Union, ) +class Balance(Mapping[str, data.Amount]): + """A collection of amounts mapped by currency + + Each key is a Beancount currency string, and each value represents the + balance in that currency. + """ + __slots__ = ('_currency_map',) + + def __init__(self, + source: Union[Iterable[Tuple[str, data.Amount]], + Mapping[str, data.Amount]]=(), + ) -> None: + if isinstance(source, Mapping): + source = source.items() + self._currency_map = { + currency: amount.number for currency, amount in source + } + + def __repr__(self) -> str: + return f"{type(self).__name__}({self._currency_map!r})" + + def __getitem__(self, key: str) -> data.Amount: + return data.Amount(self._currency_map[key], key) + + def __iter__(self) -> Iterator[str]: + return iter(self._currency_map) + + def __len__(self) -> int: + return len(self._currency_map) + + def is_zero(self) -> bool: + return all(number == 0 for number in self._currency_map.values()) + + +class MutableBalance(Balance): + __slots__ = () + + def add_amount(self, amount: data.Amount) -> None: + try: + self._currency_map[amount.currency] += amount.number + except KeyError: + self._currency_map[amount.currency] = amount.number + + class RelatedPostings(Sequence[data.Posting]): """Collect and query related postings @@ -72,8 +121,8 @@ class RelatedPostings(Sequence[data.Posting]): def add(self, post: data.Posting) -> None: self._postings.append(post) - def balance(self) -> Sequence[data.Amount]: - currency_balance: Dict[str, Decimal] = collections.defaultdict(Decimal) + def balance(self) -> Balance: + balance = MutableBalance() for post in self: - currency_balance[post.units.currency] += post.units.number - return [data.Amount(number, key) for key, number in currency_balance.items()] + balance.add_amount(post.units) + return balance diff --git a/tests/test_reports_balance.py b/tests/test_reports_balance.py new file mode 100644 index 0000000..68cb98c --- /dev/null +++ b/tests/test_reports_balance.py @@ -0,0 +1,68 @@ +"""test_reports_balance - Unit tests for reports.core.Balance""" +# Copyright © 2020 Brett Smith +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +import itertools + +from decimal import Decimal + +import pytest + +from . import testutil + +from conservancy_beancount.reports import core + +def test_empty_balance(): + balance = core.Balance() + assert not balance + assert len(balance) == 0 + assert balance.is_zero() + with pytest.raises(KeyError): + balance['USD'] + +@pytest.mark.parametrize('currencies', [ + 'USD', + 'EUR GBP', + 'JPY INR BRL', +]) +def test_zero_balance(currencies): + keys = currencies.split() + balance = core.Balance(testutil.balance_map((key, 0) for key in keys)) + assert balance + assert len(balance) == len(keys) + assert balance.is_zero() + assert all(balance[key].number == 0 for key in keys) + assert all(balance[key].currency == key for key in keys) + +@pytest.mark.parametrize('currencies', [ + 'USD', + 'EUR GBP', + 'JPY INR BRL', +]) +def test_nonzero_balance(currencies): + amounts = testutil.balance_map(zip(currencies.split(), itertools.count(110, 100))) + balance = core.Balance(amounts.items()) + assert balance + assert len(balance) == len(amounts) + assert not balance.is_zero() + assert all(balance[key] == amt for key, amt in amounts.items()) + +def test_mixed_balance(): + amounts = testutil.balance_map(USD=0, EUR=120) + balance = core.Balance(amounts.items()) + assert balance + assert len(balance) == 2 + assert not balance.is_zero() + assert all(balance[key] == amt for key, amt in amounts.items()) diff --git a/tests/test_reports_related_postings.py b/tests/test_reports_related_postings.py index 9a70d8e..a915c4e 100644 --- a/tests/test_reports_related_postings.py +++ b/tests/test_reports_related_postings.py @@ -17,6 +17,8 @@ import datetime import itertools +from decimal import Decimal + import pytest from . import testutil @@ -53,17 +55,17 @@ def donation(amount, currency='USD', date=None, other_acct='Assets:Cash', **meta def test_balance(): related = core.RelatedPostings() related.add(data.Posting.from_beancount(donation(10), 0)) - assert related.balance() == [testutil.Amount(-10)] + assert related.balance() == testutil.balance_map(USD=-10) related.add(data.Posting.from_beancount(donation(15), 0)) - assert related.balance() == [testutil.Amount(-25)] + assert related.balance() == testutil.balance_map(USD=-25) related.add(data.Posting.from_beancount(donation(20), 0)) - assert related.balance() == [testutil.Amount(-45)] + assert related.balance() == testutil.balance_map(USD=-45) def test_balance_zero(): related = core.RelatedPostings() related.add(data.Posting.from_beancount(donation(10), 0)) related.add(data.Posting.from_beancount(donation(-10), 0)) - assert related.balance() == [testutil.Amount(0)] + assert related.balance().is_zero() def test_balance_multiple_currencies(): related = core.RelatedPostings() @@ -71,17 +73,11 @@ def test_balance_multiple_currencies(): related.add(data.Posting.from_beancount(donation(15, 'GBP'), 0)) related.add(data.Posting.from_beancount(donation(20, 'EUR'), 0)) related.add(data.Posting.from_beancount(donation(25, 'EUR'), 0)) - assert set(related.balance()) == { - testutil.Amount(-25, 'GBP'), - testutil.Amount(-45, 'EUR'), - } + assert related.balance() == testutil.balance_map(EUR=-45, GBP=-25) def test_balance_multiple_currencies_one_zero(): related = core.RelatedPostings() related.add(data.Posting.from_beancount(donation(10, 'EUR'), 0)) related.add(data.Posting.from_beancount(donation(15, 'USD'), 0)) related.add(data.Posting.from_beancount(donation(-10, 'EUR'), 0)) - assert set(related.balance()) == { - testutil.Amount(-15, 'USD'), - testutil.Amount(0, 'EUR'), - } + assert related.balance() == testutil.balance_map(EUR=0, USD=-15) diff --git a/tests/testutil.py b/tests/testutil.py index c33519a..f4ef5e8 100644 --- a/tests/testutil.py +++ b/tests/testutil.py @@ -113,6 +113,18 @@ OPENING_EQUITY_ACCOUNTS = itertools.cycle([ 'Equity:OpeningBalance', ]) +def balance_map(source=None, **kwargs): + # The source and/or kwargs should map currency name strings to + # things you can pass to Decimal (a decimal string, an int, etc.) + # This returns a dict that maps currency name strings to Amount instances. + retval = {} + if source is not None: + retval.update((currency, Amount(number, currency)) + for currency, number in source) + if kwargs: + retval.update(balance_map(kwargs.items())) + return retval + class Transaction: def __init__(self, date=FY_MID_DATE, flag='*', payee=None,