247 lines
7.7 KiB
Python
247 lines
7.7 KiB
Python
"""core.py - Common data classes for reporting functionality"""
|
|
# Copyright © 2020 Brett Smith
|
|
#
|
|
# This program is free software: you can redistribute it and/or modify
|
|
# it under the terms of the GNU Affero General Public License as published by
|
|
# the Free Software Foundation, either version 3 of the License, or
|
|
# (at your option) any later version.
|
|
#
|
|
# This program is distributed in the hope that it will be useful,
|
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
# GNU Affero General Public License for more details.
|
|
#
|
|
# You should have received a copy of the GNU Affero General Public License
|
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|
|
|
import collections
|
|
import operator
|
|
|
|
from decimal import Decimal
|
|
|
|
import babel.numbers # type:ignore[import]
|
|
|
|
from .. import data
|
|
|
|
from typing import (
|
|
overload,
|
|
Any,
|
|
Callable,
|
|
DefaultDict,
|
|
Dict,
|
|
Iterable,
|
|
Iterator,
|
|
List,
|
|
Mapping,
|
|
Optional,
|
|
Sequence,
|
|
Set,
|
|
Tuple,
|
|
Union,
|
|
)
|
|
from ..beancount_types import (
|
|
MetaKey,
|
|
MetaValue,
|
|
)
|
|
|
|
# Module-level alias for data.DecimalCompat; used below in type annotations.
DecimalCompat = data.DecimalCompat
|
|
|
|
class Balance(Mapping[str, data.Amount]):
    """Per-currency totals exposed as a read-only mapping

    Keys are Beancount currency strings. Looking up a key returns a
    ``data.Amount`` built from the number stored for that currency.
    """
    __slots__ = ('_currency_map',)

    def __init__(self,
                 source: Union[Iterable[Tuple[str, data.Amount]],
                               Mapping[str, data.Amount]]=(),
    ) -> None:
        """Build a balance from a mapping or from (currency, amount) pairs

        Only the ``number`` of each amount is kept, stored under the key it
        was paired with in ``source``.
        """
        pairs = source.items() if isinstance(source, Mapping) else source
        self._currency_map = {code: amt.number for code, amt in pairs}

    def __repr__(self) -> str:
        class_name = type(self).__name__
        return f"{class_name}({self._currency_map!r})"

    def __str__(self) -> str:
        """Return the balance rendered by format() with its defaults"""
        return self.format()

    def __eq__(self, other: Any) -> bool:
        """Compare balances, treating any two zero balances as equal

        When both sides are balances whose amounts are all zero, they are
        equal no matter which currencies they list. Every other comparison
        is delegated to the inherited mapping comparison.
        """
        both_zero = (
            self.is_zero()
            and isinstance(other, Balance)
            and other.is_zero()
        )
        if both_zero:
            return True
        return super().__eq__(other)

    def __neg__(self) -> 'Balance':
        """Return a new balance of the same class with each amount negated"""
        negated = ((code, -amt) for code, amt in self.items())
        return type(self)(negated)

    def __getitem__(self, key: str) -> data.Amount:
        number = self._currency_map[key]
        return data.Amount(number, key)

    def __iter__(self) -> Iterator[str]:
        return iter(self._currency_map)

    def __len__(self) -> int:
        return len(self._currency_map)

    def _all_amounts(self,
                     op_func: Callable[[DecimalCompat, DecimalCompat], bool],
                     operand: DecimalCompat,
    ) -> bool:
        """Return True if op_func(number, operand) holds for every number

        An empty balance has no numbers to test, so the result is True.
        """
        for number in self._currency_map.values():
            if not op_func(number, operand):
                return False
        return True

    def eq_zero(self) -> bool:
        """Returns true if every amount in the balance equals zero."""
        return self._all_amounts(operator.eq, 0)

    # Bound at class level: is_zero is the very same function as eq_zero.
    is_zero = eq_zero

    def ge_zero(self) -> bool:
        """Returns true if every amount in the balance is zero or more."""
        return self._all_amounts(operator.ge, 0)

    def le_zero(self) -> bool:
        """Returns true if every amount in the balance is zero or less."""
        return self._all_amounts(operator.le, 0)

    def format(self,
               fmt: Optional[str]='#,#00.00 ¤¤',
               sep: str=', ',
               empty: str="Zero balance",
    ) -> str:
        """Render the balance as a single string

        Amounts whose number is zero are left out. If nothing remains,
        ``empty`` is returned. Otherwise the remaining amounts are ordered
        from largest to smallest absolute value, each one is formatted with
        ``fmt``, and the results are joined with ``sep``.

        Pass ``fmt=None`` to format amounts according to the user's locale.
        The default format is Beancount's input format.
        """
        nonzero = sorted(
            (amt for amt in self.values() if amt.number),
            key=lambda amt: abs(amt.number),
            reverse=True,
        )
        if not nonzero:
            return empty
        rendered = [
            babel.numbers.format_currency(amt.number, amt.currency, fmt)
            for amt in nonzero
        ]
        return sep.join(rendered)
|
|
|
|
|
|
class MutableBalance(Balance):
    """A Balance that can be updated in place with add_amount()"""
    __slots__ = ()

    def add_amount(self, amount: data.Amount) -> None:
        """Add ``amount.number`` to the total kept for ``amount.currency``

        A currency not yet in the balance starts out at ``amount.number``.
        """
        currency = amount.currency
        if currency in self._currency_map:
            self._currency_map[currency] += amount.number
        else:
            self._currency_map[currency] = amount.number
|
|
|
|
|
|
class RelatedPostings(Sequence[data.Posting]):
    """Collect and query related postings

    This class provides common functionality for collecting related postings
    and running queries on them: iterating over them, tallying their balance,
    etc.

    This class doesn't know anything about how the postings are related. That's
    entirely up to the caller.

    A common pattern is to use this class with collections.defaultdict
    to organize postings based on some key. See the group_by_meta classmethod
    for an example.
    """
    __slots__ = ('_postings',)

    def __init__(self, source: Iterable[data.Posting]=()) -> None:
        """Start the collection with a copy of the postings in ``source``"""
        self._postings: List[data.Posting] = list(source)

    @classmethod
    def group_by_meta(cls,
                      postings: Iterable[data.Posting],
                      key: MetaKey,
                      default: Optional[MetaValue]=None,
    ) -> Mapping[Optional[MetaValue], 'RelatedPostings']:
        """Relate postings by metadata value

        This method takes an iterable of postings and returns a mapping.
        The keys of the mapping are the values of post.meta.get(key, default).
        The values are RelatedPostings instances that contain all the postings
        that had that same metadata value.
        """
        retval: DefaultDict[Optional[MetaValue], 'RelatedPostings'] = collections.defaultdict(cls)
        for post in postings:
            retval[post.meta.get(key, default)].add(post)
        # With the factory cleared, looking up a key that was never seen
        # raises KeyError instead of silently creating an empty group.
        retval.default_factory = None
        return retval

    @overload
    def __getitem__(self, index: int) -> data.Posting: ...

    @overload
    def __getitem__(self, s: slice) -> Sequence[data.Posting]: ...

    def __getitem__(self,
                    index: Union[int, slice],
    ) -> Union[data.Posting, Sequence[data.Posting]]:
        """Return one posting, or a new collection for a slice

        An integer index returns the posting at that position and raises
        IndexError when out of range. A slice returns a new instance of this
        class holding the selected postings.
        """
        if isinstance(index, slice):
            return type(self)(self._postings[index])
        else:
            return self._postings[index]

    def __len__(self) -> int:
        return len(self._postings)

    def add(self, post: data.Posting) -> None:
        """Append ``post`` to the end of the collection"""
        self._postings.append(post)

    def all_meta_links(self, key: MetaKey) -> Set[str]:
        """Return the union of post.meta.get_links(key) over every posting

        A posting whose get_links() call raises TypeError contributes
        nothing to the result.
        """
        retval: Set[str] = set()
        for post in self:
            try:
                retval.update(post.meta.get_links(key))
            except TypeError:
                # Deliberate best-effort: skip this posting and carry on.
                # NOTE(review): presumably raised when the metadata value
                # is not link text - confirm against get_links().
                pass
        return retval

    def clear(self) -> None:
        """Remove every posting from the collection"""
        self._postings.clear()

    def iter_with_balance(self) -> Iterator[Tuple[data.Posting, Balance]]:
        """Yield each posting with the running balance of units so far

        The balance in each pair includes the units of that posting and of
        all postings before it. Each yielded balance is an independent
        snapshot, so it stays valid after the iteration moves on.
        """
        running = MutableBalance()
        for post in self:
            running.add_amount(post.units)
            # Yield a copy: handing out the accumulator itself would make
            # every previously yielded balance change under the caller.
            yield post, Balance(running)

    def balance(self) -> Balance:
        """Return the total of ``post.units`` over every posting

        An empty collection yields an empty balance.
        """
        total = MutableBalance()
        for post in self:
            total.add_amount(post.units)
        return total

    def balance_at_cost(self) -> Balance:
        """Return the total of every posting, valued at cost where present

        A posting with no cost contributes its units unchanged. A posting
        with a cost contributes ``units.number * cost.number`` in the cost's
        currency.
        """
        balance = MutableBalance()
        for post in self:
            if post.cost is None:
                balance.add_amount(post.units)
            else:
                number = post.units.number * post.cost.number
                balance.add_amount(data.Amount(number, post.cost.currency))
        return balance

    def meta_values(self,
                    key: MetaKey,
                    default: Optional[MetaValue]=None,
    ) -> Set[Optional[MetaValue]]:
        """Return the set of post.meta.get(key, default) over every posting"""
        return {post.meta.get(key, default) for post in self}
|