conservancy_beancount/conservancy_beancount/reports/core.py

256 lines
8 KiB
Python
Raw Normal View History

"""core.py - Common data classes for reporting functionality"""
# Copyright © 2020 Brett Smith
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
import collections
import operator
from decimal import Decimal
2020-05-28 08:39:50 -04:00
import babel.numbers # type:ignore[import]
from .. import data
from typing import (
overload,
2020-05-28 08:39:50 -04:00
Any,
Callable,
DefaultDict,
Dict,
2020-04-12 11:00:41 -04:00
Iterable,
Iterator,
List,
2020-04-12 11:00:41 -04:00
Mapping,
Optional,
Sequence,
Set,
2020-04-12 11:00:41 -04:00
Tuple,
Type,
TypeVar,
Union,
)
from ..beancount_types import (
MetaKey,
MetaValue,
)
# Self-type variable: lets classmethods and slicing on RelatedPostings be
# typed as returning the (sub)class they were called on.
RelatedType = TypeVar('RelatedType', bound='RelatedPostings')
# Module-level alias of data.DecimalCompat, used in Balance's signatures.
DecimalCompat = data.DecimalCompat
class Balance(Mapping[str, data.Amount]):
    """Read-only mapping from currency code to the amount held in it.

    Keys are Beancount currency strings. Internally only the numeric part
    of each amount is stored; ``data.Amount`` objects are rebuilt on lookup.

    Equality: two balances that are both entirely zero always compare equal.
    In every other case comparison falls back to the generic Mapping
    comparison, in which explicit zero entries still count as keys.
    """
    __slots__ = ('_currency_map',)

    def __init__(self,
                 source: Union[Iterable[Tuple[str, data.Amount]],
                               Mapping[str, data.Amount]]=(),
    ) -> None:
        # Accept either a mapping or an iterable of (currency, amount) pairs.
        # If a currency repeats, the last amount given for it wins.
        pairs = source.items() if isinstance(source, Mapping) else source
        self._currency_map = {cur: amt.number for cur, amt in pairs}

    def __repr__(self) -> str:
        cls_name = type(self).__name__
        return f"{cls_name}({self._currency_map!r})"

    def __str__(self) -> str:
        return self.format()

    def __eq__(self, other: Any) -> bool:
        # Any two all-zero balances are equal, whatever currencies they list.
        if self.is_zero() and isinstance(other, Balance) and other.is_zero():
            return True
        return super().__eq__(other)

    def __neg__(self) -> 'Balance':
        negated_pairs = [(cur, -amt) for cur, amt in self.items()]
        return type(self)(negated_pairs)

    def __getitem__(self, key: str) -> data.Amount:
        number = self._currency_map[key]
        return data.Amount(number, key)

    def __iter__(self) -> Iterator[str]:
        return iter(self._currency_map.keys())

    def __len__(self) -> int:
        return len(self._currency_map)

    def _all_amounts(self,
                     op_func: Callable[[DecimalCompat, DecimalCompat], bool],
                     operand: DecimalCompat,
    ) -> bool:
        """Return True if ``op_func(number, operand)`` holds for every number.

        Vacuously True for an empty balance.
        """
        for number in self._currency_map.values():
            if not op_func(number, operand):
                return False
        return True

    def eq_zero(self) -> bool:
        """Return True if every amount in the balance equals 0."""
        return self._all_amounts(operator.eq, 0)

    is_zero = eq_zero

    def ge_zero(self) -> bool:
        """Return True if every amount in the balance is at least 0."""
        return self._all_amounts(operator.ge, 0)

    def le_zero(self) -> bool:
        """Return True if every amount in the balance is at most 0."""
        return self._all_amounts(operator.le, 0)

    def format(self,
               fmt: Optional[str]='#,#00.00 ¤¤',
               sep: str=', ',
               empty: str="Zero balance",
    ) -> str:
        """Render the balance as a single string.

        Zero amounts are omitted. When nothing is left, ``empty`` is
        returned. Otherwise each remaining amount is rendered with
        ``babel.numbers.format_currency`` using the pattern ``fmt``, largest
        magnitude first, and the pieces are joined with ``sep``.

        Passing ``fmt=None`` lets babel format according to the user's
        locale. The default pattern follows Beancount's input format.
        """
        ordered = sorted(
            (amt for amt in self.values() if amt.number),
            key=lambda amt: abs(amt.number),
            reverse=True,
        )
        if not ordered:
            return empty
        rendered = [
            babel.numbers.format_currency(amt.number, amt.currency, fmt)
            for amt in ordered
        ]
        return sep.join(rendered)
class MutableBalance(Balance):
    """A Balance whose per-currency totals can be accumulated in place."""
    __slots__ = ()

    def add_amount(self, amount: data.Amount) -> None:
        """Add ``amount`` to the running total for its currency.

        A currency not yet present starts at ``amount.number``.
        """
        totals = self._currency_map
        currency = amount.currency
        if currency in totals:
            totals[currency] += amount.number
        else:
            totals[currency] = amount.number
class RelatedPostings(Sequence[data.Posting]):
    """Collect and query related postings

    This class provides common functionality for collecting related postings
    and running queries on them: iterating over them, tallying their balance,
    etc.

    This class doesn't know anything about how the postings are related. That's
    entirely up to the caller.

    A common pattern is to use this class with collections.defaultdict
    to organize postings based on some key. See the group_by_meta classmethod
    for an example.
    """
    __slots__ = ('_postings',)

    def __init__(self,
                 source: Iterable[data.Posting]=(),
                 *,
                 _can_own: bool=False,
    ) -> None:
        """Build a collection from ``source``.

        Normally ``source`` is copied into a new list. When ``_can_own`` is
        true and ``source`` is already a list, that list is adopted as-is
        without copying, so the caller must not mutate it afterward.
        """
        self._postings: List[data.Posting]
        if _can_own and isinstance(source, list):
            self._postings = source
        else:
            self._postings = list(source)

    @classmethod
    def group_by_meta(cls: Type[RelatedType],
                      postings: Iterable[data.Posting],
                      key: MetaKey,
                      default: Optional[MetaValue]=None,
    ) -> Iterator[Tuple[Optional[MetaValue], RelatedType]]:
        """Relate postings by metadata value

        This is a generator that yields ``(value, related)`` pairs. It does
        not return a mapping; pass the result to ``dict()`` if you need one.

        ``value`` is ``post.meta.get(key, default)``. ``related`` is an
        instance of ``cls`` holding every posting with that value, in input
        order. Groups come out in the order their value was first seen. All
        of ``postings`` is consumed when the first pair is requested.
        """
        mapping: DefaultDict[Optional[MetaValue], List[data.Posting]] = collections.defaultdict(list)
        for post in postings:
            mapping[post.meta.get(key, default)].append(post)
        for value, posts in mapping.items():
            # Each list is private to this call, so the instance may own it.
            yield value, cls(posts, _can_own=True)

    def __repr__(self) -> str:
        return f'<{type(self).__name__} {self._postings!r}>'

    @overload
    def __getitem__(self: RelatedType, index: int) -> data.Posting: ...

    @overload
    def __getitem__(self: RelatedType, s: slice) -> RelatedType: ...

    def __getitem__(self: RelatedType,
                    index: Union[int, slice],
    ) -> Union[data.Posting, RelatedType]:
        """Return one posting for an int index, or a new instance for a slice."""
        if isinstance(index, slice):
            # Slicing a list already produces a fresh list, so hand it over.
            return type(self)(self._postings[index], _can_own=True)
        else:
            return self._postings[index]

    def __iter__(self) -> Iterator[data.Posting]:
        # Iterate the underlying list directly instead of the Sequence
        # mixin's index-by-index __getitem__ loop; the result is the same.
        return iter(self._postings)

    def __len__(self) -> int:
        return len(self._postings)

    def all_meta_links(self, key: MetaKey) -> Set[str]:
        """Return the union of ``post.meta.get_links(key)`` over all postings.

        Postings for which that call raises TypeError are skipped.
        """
        retval: Set[str] = set()
        for post in self:
            try:
                retval.update(post.meta.get_links(key))
            except TypeError:
                # NOTE(review): presumably get_links raises TypeError when the
                # metadata value is not a links string; such postings are
                # deliberately ignored. Confirm against data.Metadata.
                pass
        return retval

    def iter_with_balance(self) -> Iterator[Tuple[data.Posting, Balance]]:
        """Yield each posting with the running balance of units through it.

        The SAME MutableBalance object is yielded every time and is updated
        in place before each yield. Copy it (e.g. ``Balance(balance)``) if
        you need a snapshot that outlives the current iteration step.
        """
        balance = MutableBalance()
        for post in self:
            balance.add_amount(post.units)
            yield post, balance

    def balance(self) -> Balance:
        """Return the total balance of all postings' units.

        This is the final running balance from iter_with_balance(), or an
        empty Balance when there are no postings.
        """
        # Pre-initialize the loop target so the empty case needs no
        # exception handling; otherwise it ends bound to the last balance.
        retval: Balance = Balance()
        for _, retval in self.iter_with_balance():
            pass
        return retval

    def balance_at_cost(self) -> Balance:
        """Return the total balance, valued at cost where a cost is present.

        A posting without a cost contributes its units unchanged. A posting
        with a cost contributes ``units.number * cost.number`` in the cost
        currency.
        """
        balance = MutableBalance()
        for post in self:
            if post.cost is None:
                balance.add_amount(post.units)
            else:
                number = post.units.number * post.cost.number
                balance.add_amount(data.Amount(number, post.cost.currency))
        return balance

    def meta_values(self,
                    key: MetaKey,
                    default: Optional[MetaValue]=None,
    ) -> Set[Optional[MetaValue]]:
        """Return the set of ``post.meta.get(key, default)`` over all postings."""
        return {post.meta.get(key, default) for post in self}