reports.core: Start Balance class.

This commit is contained in:
Brett Smith 2020-04-12 11:00:41 -04:00
parent 219cd4bc37
commit 5aa30e5456
4 changed files with 141 additions and 16 deletions

View file

@ -23,11 +23,60 @@ from .. import data
from typing import (
overload,
Dict,
Iterable,
Iterator,
List,
Mapping,
Optional,
Sequence,
Tuple,
Union,
)
class Balance(Mapping[str, data.Amount]):
"""A collection of amounts mapped by currency
Each key is a Beancount currency string, and each value represents the
balance in that currency.
"""
__slots__ = ('_currency_map',)
def __init__(self,
source: Union[Iterable[Tuple[str, data.Amount]],
Mapping[str, data.Amount]]=(),
) -> None:
if isinstance(source, Mapping):
source = source.items()
self._currency_map = {
currency: amount.number for currency, amount in source
}
def __repr__(self) -> str:
return f"{type(self).__name__}({self._currency_map!r})"
def __getitem__(self, key: str) -> data.Amount:
return data.Amount(self._currency_map[key], key)
def __iter__(self) -> Iterator[str]:
return iter(self._currency_map)
def __len__(self) -> int:
return len(self._currency_map)
def is_zero(self) -> bool:
return all(number == 0 for number in self._currency_map.values())
class MutableBalance(Balance):
__slots__ = ()
def add_amount(self, amount: data.Amount) -> None:
try:
self._currency_map[amount.currency] += amount.number
except KeyError:
self._currency_map[amount.currency] = amount.number
class RelatedPostings(Sequence[data.Posting]):
"""Collect and query related postings
@ -72,8 +121,8 @@ class RelatedPostings(Sequence[data.Posting]):
def add(self, post: data.Posting) -> None:
self._postings.append(post)
def balance(self) -> Sequence[data.Amount]:
currency_balance: Dict[str, Decimal] = collections.defaultdict(Decimal)
def balance(self) -> Balance:
balance = MutableBalance()
for post in self:
currency_balance[post.units.currency] += post.units.number
return [data.Amount(number, key) for key, number in currency_balance.items()]
balance.add_amount(post.units)
return balance

View file

@ -0,0 +1,68 @@
"""test_reports_balance - Unit tests for reports.core.Balance"""
# Copyright © 2020 Brett Smith
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
import itertools
from decimal import Decimal
import pytest
from . import testutil
from conservancy_beancount.reports import core
def test_empty_balance():
balance = core.Balance()
assert not balance
assert len(balance) == 0
assert balance.is_zero()
with pytest.raises(KeyError):
balance['USD']
@pytest.mark.parametrize('currencies', [
'USD',
'EUR GBP',
'JPY INR BRL',
])
def test_zero_balance(currencies):
keys = currencies.split()
balance = core.Balance(testutil.balance_map((key, 0) for key in keys))
assert balance
assert len(balance) == len(keys)
assert balance.is_zero()
assert all(balance[key].number == 0 for key in keys)
assert all(balance[key].currency == key for key in keys)
@pytest.mark.parametrize('currencies', [
'USD',
'EUR GBP',
'JPY INR BRL',
])
def test_nonzero_balance(currencies):
amounts = testutil.balance_map(zip(currencies.split(), itertools.count(110, 100)))
balance = core.Balance(amounts.items())
assert balance
assert len(balance) == len(amounts)
assert not balance.is_zero()
assert all(balance[key] == amt for key, amt in amounts.items())
def test_mixed_balance():
amounts = testutil.balance_map(USD=0, EUR=120)
balance = core.Balance(amounts.items())
assert balance
assert len(balance) == 2
assert not balance.is_zero()
assert all(balance[key] == amt for key, amt in amounts.items())

View file

@ -17,6 +17,8 @@
import datetime
import itertools
from decimal import Decimal
import pytest
from . import testutil
@ -53,17 +55,17 @@ def donation(amount, currency='USD', date=None, other_acct='Assets:Cash', **meta
def test_balance():
related = core.RelatedPostings()
related.add(data.Posting.from_beancount(donation(10), 0))
assert related.balance() == [testutil.Amount(-10)]
assert related.balance() == testutil.balance_map(USD=-10)
related.add(data.Posting.from_beancount(donation(15), 0))
assert related.balance() == [testutil.Amount(-25)]
assert related.balance() == testutil.balance_map(USD=-25)
related.add(data.Posting.from_beancount(donation(20), 0))
assert related.balance() == [testutil.Amount(-45)]
assert related.balance() == testutil.balance_map(USD=-45)
def test_balance_zero():
related = core.RelatedPostings()
related.add(data.Posting.from_beancount(donation(10), 0))
related.add(data.Posting.from_beancount(donation(-10), 0))
assert related.balance() == [testutil.Amount(0)]
assert related.balance().is_zero()
def test_balance_multiple_currencies():
related = core.RelatedPostings()
@ -71,17 +73,11 @@ def test_balance_multiple_currencies():
related.add(data.Posting.from_beancount(donation(15, 'GBP'), 0))
related.add(data.Posting.from_beancount(donation(20, 'EUR'), 0))
related.add(data.Posting.from_beancount(donation(25, 'EUR'), 0))
assert set(related.balance()) == {
testutil.Amount(-25, 'GBP'),
testutil.Amount(-45, 'EUR'),
}
assert related.balance() == testutil.balance_map(EUR=-45, GBP=-25)
def test_balance_multiple_currencies_one_zero():
related = core.RelatedPostings()
related.add(data.Posting.from_beancount(donation(10, 'EUR'), 0))
related.add(data.Posting.from_beancount(donation(15, 'USD'), 0))
related.add(data.Posting.from_beancount(donation(-10, 'EUR'), 0))
assert set(related.balance()) == {
testutil.Amount(-15, 'USD'),
testutil.Amount(0, 'EUR'),
}
assert related.balance() == testutil.balance_map(EUR=0, USD=-15)

View file

@ -113,6 +113,18 @@ OPENING_EQUITY_ACCOUNTS = itertools.cycle([
'Equity:OpeningBalance',
])
def balance_map(source=None, **kwargs):
# The source and/or kwargs should map currency name strings to
# things you can pass to Decimal (a decimal string, an int, etc.)
# This returns a dict that maps currency name strings to Amount instances.
retval = {}
if source is not None:
retval.update((currency, Amount(number, currency))
for currency, number in source)
if kwargs:
retval.update(balance_map(kwargs.items()))
return retval
class Transaction:
def __init__(self,
date=FY_MID_DATE, flag='*', payee=None,