conservancy_beancount/conservancy_beancount/reports/core.py

151 lines
4.4 KiB
Python
Raw Normal View History

"""core.py - Common data classes for reporting functionality"""
# Copyright © 2020 Brett Smith
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
import collections
from decimal import Decimal
from .. import data
from typing import (
overload,
Dict,
2020-04-12 11:00:41 -04:00
Iterable,
Iterator,
List,
2020-04-12 11:00:41 -04:00
Mapping,
Optional,
Sequence,
Set,
2020-04-12 11:00:41 -04:00
Tuple,
Union,
)
from ..beancount_types import (
MetaKey,
MetaValue,
)
2020-04-12 11:00:41 -04:00
class Balance(Mapping[str, data.Amount]):
    """Map Beancount currency strings to the amount held in that currency

    Each key is a currency string. Looking a key up returns a ``data.Amount``
    built from the stored number and that key; internally only the numbers
    are kept.
    """
    __slots__ = ('_currency_map',)

    def __init__(self,
                 source: Union[Iterable[Tuple[str, data.Amount]],
                               Mapping[str, data.Amount]]=(),
    ) -> None:
        # A mapping is read through its (currency, amount) items; any other
        # iterable must already yield those pairs. As with a dict, a later
        # pair for the same currency replaces an earlier one.
        pairs = source.items() if isinstance(source, Mapping) else source
        self._currency_map = dict(
            (currency, amount.number) for currency, amount in pairs
        )

    def __repr__(self) -> str:
        class_name = type(self).__name__
        return f"{class_name}({self._currency_map!r})"

    def __getitem__(self, key: str) -> data.Amount:
        # The returned Amount takes its currency from the lookup key.
        number = self._currency_map[key]
        return data.Amount(number, key)

    def __iter__(self) -> Iterator[str]:
        return iter(self._currency_map.keys())

    def __len__(self) -> int:
        return len(self._currency_map)

    def is_zero(self) -> bool:
        """Return True when no currency holds a nonzero number.

        An empty balance counts as zero.
        """
        return not any(number != 0 for number in self._currency_map.values())
class MutableBalance(Balance):
    """A Balance that can be updated in place with add_amount"""
    __slots__ = ()

    def add_amount(self, amount: data.Amount) -> None:
        """Add ``amount.number`` to the number tracked for ``amount.currency``.

        The first amount seen for a currency is stored as-is.
        """
        numbers = self._currency_map
        currency = amount.currency
        if currency in numbers:
            numbers[currency] += amount.number
        else:
            numbers[currency] = amount.number
class RelatedPostings(Sequence[data.Posting]):
    """Collect and query related postings

    This class provides common functionality for collecting related postings
    and running queries on them: iterating over them, tallying their balance,
    etc.

    This class doesn't know anything about how the postings are related. That's
    entirely up to the caller.

    A common pattern is to use this class with collections.defaultdict
    to organize postings based on some key::

      report = collections.defaultdict(RelatedPostings)
      for txn in transactions:
          for post in Posting.from_txn(txn):
              if should_report(post):
                  key = post_key(post)
                  report[key].add(post)
    """

    def __init__(self) -> None:
        self._postings: List[data.Posting] = []

    @overload
    def __getitem__(self, index: int) -> data.Posting: ...

    @overload
    def __getitem__(self, s: slice) -> Sequence[data.Posting]: ...

    def __getitem__(self,
                    index: Union[int, slice],
    ) -> Union[data.Posting, Sequence[data.Posting]]:
        """Return the posting at ``index``, or a list of postings for a slice.

        A slice returns a new list, so later changes to this collection do
        not affect it. Out-of-range integer indexes raise IndexError.
        """
        # A list slice is already a Sequence of postings, so both kinds of
        # index can be delegated to the backing list.
        return self._postings[index]

    def __iter__(self) -> Iterator[data.Posting]:
        # Iterate the backing list directly rather than using Sequence's
        # default __iter__, which calls __getitem__ once per element until
        # it hits IndexError.
        return iter(self._postings)

    def __len__(self) -> int:
        return len(self._postings)

    def add(self, post: data.Posting) -> None:
        """Append ``post`` to the end of this collection."""
        self._postings.append(post)

    def clear(self) -> None:
        """Remove every posting from this collection."""
        self._postings.clear()

    def iter_with_balance(self) -> Iterator[Tuple[data.Posting, Balance]]:
        """Iterate over postings paired with the running balance.

        Yields ``(post, balance)`` in collection order. ``balance`` is the sum
        of ``units`` for every posting up to and including ``post``. Each
        yielded balance is an independent snapshot, so it stays accurate
        after iteration continues.
        """
        running = MutableBalance()
        for post in self:
            running.add_amount(post.units)
            # Yield a copy: yielding ``running`` itself would hand the caller
            # the same object each time, which keeps changing as the loop
            # advances (e.g. list(self.iter_with_balance()) would show the
            # final total for every posting).
            yield post, Balance(running)

    def balance(self) -> Balance:
        """Return the sum of ``units`` across all postings, by currency.

        An empty collection produces an empty balance. The result is a new
        object owned by the caller.
        """
        total = MutableBalance()
        for post in self:
            total.add_amount(post.units)
        return total

    def meta_values(self,
                    key: MetaKey,
                    default: Optional[MetaValue]=None,
    ) -> Set[Optional[MetaValue]]:
        """Return the distinct values of metadata ``key`` across all postings.

        Postings whose metadata lacks ``key`` contribute ``default``.
        """
        return {post.meta.get(key, default) for post in self}