180 lines
5.6 KiB
Python
180 lines
5.6 KiB
Python
"""core.py - Common data classes for reporting functionality"""
|
|
# Copyright © 2020 Brett Smith
|
|
#
|
|
# This program is free software: you can redistribute it and/or modify
|
|
# it under the terms of the GNU Affero General Public License as published by
|
|
# the Free Software Foundation, either version 3 of the License, or
|
|
# (at your option) any later version.
|
|
#
|
|
# This program is distributed in the hope that it will be useful,
|
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
# GNU Affero General Public License for more details.
|
|
#
|
|
# You should have received a copy of the GNU Affero General Public License
|
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|
|
|
import collections
|
|
|
|
from decimal import Decimal
|
|
|
|
from .. import data
|
|
|
|
from typing import (
|
|
overload,
|
|
DefaultDict,
|
|
Dict,
|
|
Iterable,
|
|
Iterator,
|
|
List,
|
|
Mapping,
|
|
Optional,
|
|
Sequence,
|
|
Set,
|
|
Tuple,
|
|
Union,
|
|
)
|
|
from ..beancount_types import (
|
|
MetaKey,
|
|
MetaValue,
|
|
)
|
|
|
|
class Balance(Mapping[str, data.Amount]):
    """Read-only mapping from currency code to the balance in that currency

    Keys are Beancount currency strings. Internally only the number for each
    currency is stored; a ``data.Amount`` is rebuilt from that number and the
    key's currency every time a value is looked up.
    """
    __slots__ = ('_currency_map',)

    def __init__(self,
                 source: Union[Iterable[Tuple[str, data.Amount]],
                               Mapping[str, data.Amount]]=(),
    ) -> None:
        """Build a balance from a mapping or from ``(currency, amount)`` pairs.

        Only each amount's ``number`` is kept; the currency recorded is the
        key/first element of the pair, not ``amount.currency``. When the same
        currency appears more than once, the last pair wins.
        """
        pairs = source.items() if isinstance(source, Mapping) else source
        self._currency_map = dict(
            (currency, amt.number) for currency, amt in pairs
        )

    def __repr__(self) -> str:
        cls_name = type(self).__name__
        return "{}({!r})".format(cls_name, self._currency_map)

    def __str__(self) -> str:
        """List the non-zero amounts, largest magnitude first.

        Returns "Zero balance" when every amount's number is falsy.
        """
        nonzero = sorted(
            (amt for amt in self.values() if amt.number),
            key=lambda amt: abs(amt.number),
            reverse=True,
        )
        if nonzero:
            return ', '.join(map(str, nonzero))
        return "Zero balance"

    def __getitem__(self, key: str) -> data.Amount:
        # Raises KeyError for currencies not present in the balance.
        number = self._currency_map[key]
        return data.Amount(number, key)

    def __iter__(self) -> Iterator[str]:
        return iter(self._currency_map.keys())

    def __len__(self) -> int:
        return len(self._currency_map)

    def is_zero(self) -> bool:
        """Return True if every stored number compares equal to 0.

        An empty balance is considered zero.
        """
        # Compare with != 0 (not truthiness) to keep the exact semantics of an
        # equality test against zero.
        return not any(number != 0 for number in self._currency_map.values())
|
|
|
|
|
|
class MutableBalance(Balance):
    """A Balance whose per-currency totals can be updated in place"""
    __slots__ = ()

    def add_amount(self, amount: data.Amount) -> None:
        """Add ``amount.number`` to the running total for ``amount.currency``.

        If the currency has no entry yet, a new one is started with
        ``amount.number``.
        """
        currency = amount.currency
        numbers = self._currency_map
        if currency in numbers:
            numbers[currency] += amount.number
        else:
            numbers[currency] = amount.number
|
|
|
|
|
|
class RelatedPostings(Sequence[data.Posting]):
    """Collect and query related postings

    This class provides common functionality for collecting related postings
    and running queries on them: iterating over them, tallying their balance,
    etc.

    This class doesn't know anything about how the postings are related. That's
    entirely up to the caller.

    A common pattern is to use this class with collections.defaultdict
    to organize postings based on some key. See the group_by_meta classmethod
    for an example.
    """

    def __init__(self, source: Iterable[data.Posting]=()) -> None:
        # Copy into a private list so later changes to the caller's iterable
        # do not affect this instance.
        self._postings: List[data.Posting] = list(source)

    @classmethod
    def group_by_meta(cls,
                      postings: Iterable[data.Posting],
                      key: MetaKey,
                      default: Optional[MetaValue]=None,
    ) -> Mapping[Optional[MetaValue], 'RelatedPostings']:
        """Relate postings by metadata value

        This method takes an iterable of postings and returns a mapping.
        The keys of the mapping are the values of post.meta.get(key, default).
        The values are RelatedPostings instances that contain all the postings
        that had that same metadata value.
        """
        retval: DefaultDict[Optional[MetaValue], 'RelatedPostings'] = collections.defaultdict(cls)
        for post in postings:
            retval[post.meta.get(key, default)].add(post)
        # Disable auto-creation so that looking up a value no posting had
        # raises KeyError instead of silently inserting an empty group.
        retval.default_factory = None
        return retval

    @overload
    def __getitem__(self, index: int) -> data.Posting: ...

    @overload
    def __getitem__(self, s: slice) -> Sequence[data.Posting]: ...

    def __getitem__(self,
                    index: Union[int, slice],
    ) -> Union[data.Posting, Sequence[data.Posting]]:
        """Return one posting by integer index, or a list of postings for a slice.

        Slicing returns a new list; modifying it does not affect this
        instance. Out-of-range integer indexes raise IndexError.
        """
        return self._postings[index]

    def __iter__(self) -> Iterator[data.Posting]:
        # Iterate the backing list directly instead of relying on Sequence's
        # index-by-index fallback.
        return iter(self._postings)

    def __len__(self) -> int:
        return len(self._postings)

    def add(self, post: data.Posting) -> None:
        """Append one posting to the collection."""
        self._postings.append(post)

    def all_meta_links(self, key: MetaKey) -> Set[str]:
        """Return the union of links found in each posting's ``key`` metadata.

        Postings for which ``post.meta.get_links(key)`` raises TypeError are
        skipped.
        """
        retval: Set[str] = set()
        for post in self:
            try:
                retval.update(post.meta.get_links(key))
            except TypeError:
                # NOTE(review): presumably raised when the metadata value is
                # not a links string; skipped deliberately - confirm in data.
                pass
        return retval

    def clear(self) -> None:
        """Remove all postings from the collection."""
        self._postings.clear()

    def iter_with_balance(self) -> Iterator[Tuple[data.Posting, Balance]]:
        """Yield each posting with the running balance including it.

        The running balance sums each posting's ``units``. The SAME
        MutableBalance object is yielded every time and updated in place, so
        callers that need a per-posting snapshot must copy it (for example
        with ``Balance(balance)``) before advancing the iterator.
        """
        balance = MutableBalance()
        for post in self:
            balance.add_amount(post.units)
            yield post, balance

    def balance(self) -> Balance:
        """Return the total balance of all postings' units.

        Returns an empty Balance when there are no postings.
        """
        # Start from an empty Balance so the result is defined even when
        # iter_with_balance yields nothing; otherwise it ends up as the last
        # balance yielded. Going through iter_with_balance keeps subclass
        # overrides of it in effect.
        total: Balance = Balance()
        for _, total in self.iter_with_balance():
            pass
        return total

    def meta_values(self,
                    key: MetaKey,
                    default: Optional[MetaValue]=None,
    ) -> Set[Optional[MetaValue]]:
        """Return the set of distinct ``post.meta.get(key, default)`` values."""
        return {post.meta.get(key, default) for post in self}
|