conservancy_beancount/conservancy_beancount/reports/core.py
2020-04-29 11:37:38 -04:00

180 lines
5.6 KiB
Python

"""core.py - Common data classes for reporting functionality"""
# Copyright © 2020 Brett Smith
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
import collections
from decimal import Decimal
from .. import data
from typing import (
overload,
DefaultDict,
Dict,
Iterable,
Iterator,
List,
Mapping,
Optional,
Sequence,
Set,
Tuple,
Union,
)
from ..beancount_types import (
MetaKey,
MetaValue,
)
class Balance(Mapping[str, data.Amount]):
    """Read-only mapping from currency code to the amount held in it.

    Keys are Beancount currency strings. Looking up a key returns a
    data.Amount for that currency. Only the numbers are stored internally;
    Amount objects are rebuilt on every lookup.
    """
    __slots__ = ('_currency_map',)

    def __init__(self,
                 source: Union[Iterable[Tuple[str, data.Amount]],
                               Mapping[str, data.Amount]]=(),
    ) -> None:
        # Accept either a mapping or an iterable of (currency, amount) pairs.
        pairs = source.items() if isinstance(source, Mapping) else source
        self._currency_map: Dict[str, Decimal] = {}
        for currency, amount in pairs:
            # Later pairs for the same currency overwrite earlier ones.
            self._currency_map[currency] = amount.number

    def __repr__(self) -> str:
        cls_name = type(self).__name__
        return f"{cls_name}({self._currency_map!r})"

    def __str__(self) -> str:
        # Omit zero amounts, and list the largest magnitudes first.
        nonzero = sorted(
            (amt for amt in self.values() if amt.number),
            key=lambda amt: abs(amt.number),
            reverse=True,
        )
        if nonzero:
            return ', '.join(map(str, nonzero))
        return "Zero balance"

    def __getitem__(self, key: str) -> data.Amount:
        number = self._currency_map[key]
        return data.Amount(number, key)

    def __iter__(self) -> Iterator[str]:
        return iter(self._currency_map)

    def __len__(self) -> int:
        return len(self._currency_map)

    def is_zero(self) -> bool:
        """Return True if every currency's number equals zero (True when empty)."""
        return not any(number != 0 for number in self._currency_map.values())
class MutableBalance(Balance):
    """A Balance whose totals can be updated in place with add_amount()."""
    __slots__ = ()

    def add_amount(self, amount: data.Amount) -> None:
        """Add amount.number to the running total for amount.currency.

        If the currency is not yet in the balance, it is added with
        amount.number as its starting total.
        """
        currency_map = self._currency_map
        currency = amount.currency
        if currency in currency_map:
            currency_map[currency] += amount.number
        else:
            currency_map[currency] = amount.number
class RelatedPostings(Sequence[data.Posting]):
    """Collect and query related postings

    This class provides common functionality for collecting related postings
    and running queries on them: iterating over them, tallying their balance,
    etc.

    This class doesn't know anything about how the postings are related. That's
    entirely up to the caller.

    A common pattern is to use this class with collections.defaultdict
    to organize postings based on some key. See the group_by_meta classmethod
    for an example.
    """

    def __init__(self, source: Iterable[data.Posting]=()) -> None:
        # Copy into a private list so later add()/clear() calls never touch
        # the caller's iterable.
        self._postings: List[data.Posting] = list(source)

    @classmethod
    def group_by_meta(cls,
                      postings: Iterable[data.Posting],
                      key: MetaKey,
                      default: Optional[MetaValue]=None,
    ) -> Mapping[Optional[MetaValue], 'RelatedPostings']:
        """Relate postings by metadata value

        This method takes an iterable of postings and returns a mapping.
        The keys of the mapping are the values of post.meta.get(key, default).
        The values are RelatedPostings instances that contain all the postings
        that had that same metadata value.

        Looking up a value that no posting had raises KeyError.
        """
        retval: DefaultDict[Optional[MetaValue], 'RelatedPostings'] = collections.defaultdict(cls)
        for post in postings:
            retval[post.meta.get(key, default)].add(post)
        # Disable the factory so lookups of absent values raise KeyError
        # instead of silently inserting empty groups.
        retval.default_factory = None
        return retval

    @overload
    def __getitem__(self, index: int) -> data.Posting: ...
    @overload
    def __getitem__(self, s: slice) -> Sequence[data.Posting]: ...
    def __getitem__(self,
                    index: Union[int, slice],
    ) -> Union[data.Posting, Sequence[data.Posting]]:
        """Return the posting at an int index, or a new list for a slice.

        Slicing returns a fresh list, so mutating the result does not change
        this collection. An out-of-range int index raises IndexError, as
        with a list.
        """
        return self._postings[index]

    def __iter__(self) -> Iterator[data.Posting]:
        # Iterate the backing list directly rather than through Sequence's
        # index-by-index fallback.
        return iter(self._postings)

    def __len__(self) -> int:
        return len(self._postings)

    def add(self, post: data.Posting) -> None:
        """Append a posting to the end of this collection."""
        self._postings.append(post)

    def all_meta_links(self, key: MetaKey) -> Set[str]:
        """Return the union of post.meta.get_links(key) across all postings."""
        retval: Set[str] = set()
        for post in self:
            try:
                retval.update(post.meta.get_links(key))
            except TypeError:
                # Deliberate best-effort: postings whose get_links() raises
                # TypeError are skipped. Presumably this happens when the
                # metadata value is not a links string (confirm in data module).
                pass
        return retval

    def clear(self) -> None:
        """Remove all postings from this collection."""
        self._postings.clear()

    def iter_with_balance(self) -> Iterator[Tuple[data.Posting, Balance]]:
        """Iterate over (posting, running balance) pairs.

        Each balance is the sum of post.units for that posting and every
        posting before it. Each yielded balance is an independent snapshot,
        so it stays correct if the caller keeps it after advancing the
        iterator.
        """
        running = MutableBalance()
        for post in self:
            running.add_amount(post.units)
            yield post, Balance(running)

    def balance(self) -> Balance:
        """Return the sum of post.units over all postings (empty if none)."""
        total = MutableBalance()
        for post in self:
            total.add_amount(post.units)
        return total

    def meta_values(self,
                    key: MetaKey,
                    default: Optional[MetaValue]=None,
    ) -> Set[Optional[MetaValue]]:
        """Return the set of post.meta.get(key, default) across all postings."""
        return {post.meta.get(key, default) for post in self}