data: Introduce Posting class.

Our version of Posting is interface-compatible with Beancount's,
but makes stronger guarantees about the data types for our
higher-level code to rely on.
This commit is contained in:
Brett Smith 2020-03-18 07:48:08 -04:00
parent 163ecbc7d3
commit b2ef561c85
4 changed files with 35 additions and 42 deletions

View file

@ -21,13 +21,14 @@ from beancount.core import account as bc_account
from typing import ( from typing import (
Iterable, Iterable,
Iterator, Iterator,
MutableMapping,
Optional, Optional,
) )
from .beancount_types import ( from .beancount_types import (
MetaKey, MetaKey,
MetaValue, MetaValue,
Posting, Posting as BasePosting,
Transaction, Transaction,
) )
@ -43,7 +44,11 @@ class Account(str):
class PostingMeta(collections.abc.MutableMapping): class PostingMeta(collections.abc.MutableMapping):
def __init__(self, txn: Transaction, index: int, post: Optional[Posting]=None) -> None: def __init__(self,
txn: Transaction,
index: int,
post: Optional[BasePosting]=None,
) -> None:
if post is None: if post is None:
post = txn.postings[index] post = txn.postings[index]
self.txn = txn self.txn = txn
@ -83,9 +88,20 @@ class PostingMeta(collections.abc.MutableMapping):
del self.post.meta[key] del self.post.meta[key]
class Posting(BasePosting):
account: Account
# mypy correctly complains that our MutableMapping is not compatible with
# Beancount's meta type declaration of Optional[Dict]. IMO this is a case
# of Beancount's type declaration being a smidge too specific: I think it
# would be very unusual for code to actually require a dict over a more
# generic mapping. If it did, this would work fine.
meta: MutableMapping[MetaKey, MetaValue]
def iter_postings(txn: Transaction) -> Iterator[Posting]: def iter_postings(txn: Transaction) -> Iterator[Posting]:
for index, source in enumerate(txn.postings): for index, source in enumerate(txn.postings):
yield source._replace( yield Posting(
account=Account(source.account), Account(source.account),
meta=PostingMeta(txn, index, source), *source[1:5],
PostingMeta(txn, index, source),
) )

View file

@ -18,6 +18,7 @@ import abc
import datetime import datetime
import re import re
from .. import data
from .. import errors as errormod from .. import errors as errormod
from typing import ( from typing import (
@ -36,7 +37,6 @@ from ..beancount_types import (
MetaKey, MetaKey,
MetaValue, MetaValue,
MetaValueEnum, MetaValueEnum,
Posting,
Transaction, Transaction,
Type, Type,
) )
@ -177,43 +177,20 @@ class _PostingHook(TransactionHook, metaclass=abc.ABCMeta):
def __init_subclass__(cls) -> None: def __init_subclass__(cls) -> None:
cls.HOOK_GROUPS = cls.HOOK_GROUPS.union(['posting']) cls.HOOK_GROUPS = cls.HOOK_GROUPS.union(['posting'])
def _meta_get(self,
txn: Transaction,
post: Posting,
key: MetaKey,
default: MetaValue=None,
) -> MetaValue:
if post.meta and key in post.meta:
return post.meta[key]
else:
return txn.meta.get(key, default)
def _meta_set(self,
txn: Transaction,
post: Posting,
post_index: int,
key: MetaKey,
value: MetaValue,
) -> None:
if post.meta is None:
txn.postings[post_index] = Posting(*post[:5], {key: value})
else:
post.meta[key] = value
def _run_on_txn(self, txn: Transaction) -> bool: def _run_on_txn(self, txn: Transaction) -> bool:
return txn.date in self.TXN_DATE_RANGE return txn.date in self.TXN_DATE_RANGE
def _run_on_post(self, txn: Transaction, post: Posting) -> bool: def _run_on_post(self, txn: Transaction, post: data.Posting) -> bool:
return True return True
def run(self, txn: Transaction) -> errormod.Iter: def run(self, txn: Transaction) -> errormod.Iter:
if self._run_on_txn(txn): if self._run_on_txn(txn):
for index, post in enumerate(txn.postings): for post in data.iter_postings(txn):
if self._run_on_post(txn, post): if self._run_on_post(txn, post):
yield from self.post_run(txn, post, index) yield from self.post_run(txn, post)
@abc.abstractmethod @abc.abstractmethod
def post_run(self, txn: Transaction, post: Posting, post_index: int) -> errormod.Iter: ... def post_run(self, txn: Transaction, post: data.Posting) -> errormod.Iter: ...
class _NormalizePostingMetadataHook(_PostingHook): class _NormalizePostingMetadataHook(_PostingHook):
@ -232,11 +209,11 @@ class _NormalizePostingMetadataHook(_PostingHook):
# _default_value to get a default. This method should either return # _default_value to get a default. This method should either return
# a value string from METADATA_ENUM, or else raise InvalidMetadataError. # a value string from METADATA_ENUM, or else raise InvalidMetadataError.
# This base implementation does the latter. # This base implementation does the latter.
def _default_value(self, txn: Transaction, post: Posting) -> MetaValueEnum: def _default_value(self, txn: Transaction, post: data.Posting) -> MetaValueEnum:
raise errormod.InvalidMetadataError(txn, post, self.METADATA_KEY) raise errormod.InvalidMetadataError(txn, post, self.METADATA_KEY)
def post_run(self, txn: Transaction, post: Posting, post_index: int) -> errormod.Iter: def post_run(self, txn: Transaction, post: data.Posting) -> errormod.Iter:
source_value = self._meta_get(txn, post, self.METADATA_KEY) source_value = post.meta.get(self.METADATA_KEY)
set_value = source_value set_value = source_value
error: Optional[errormod.Error] = None error: Optional[errormod.Error] = None
if source_value is None: if source_value is None:
@ -252,6 +229,6 @@ class _NormalizePostingMetadataHook(_PostingHook):
txn, post, self.METADATA_KEY, source_value, txn, post, self.METADATA_KEY, source_value,
) )
if error is None: if error is None:
self._meta_set(txn, post, post_index, self.METADATA_KEY, set_value) post.meta[self.METADATA_KEY] = set_value
else: else:
yield error yield error

View file

@ -15,9 +15,9 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from . import core from . import core
from .. import data
from ..beancount_types import ( from ..beancount_types import (
MetaValueEnum, MetaValueEnum,
Posting,
Transaction, Transaction,
) )
@ -35,8 +35,8 @@ class MetaExpenseAllocation(core._NormalizePostingMetadataHook):
'Expenses:Services:Fundraising': VALUES_ENUM['fundraising'], 'Expenses:Services:Fundraising': VALUES_ENUM['fundraising'],
} }
def _run_on_post(self, txn: Transaction, post: Posting) -> bool: def _run_on_post(self, txn: Transaction, post: data.Posting) -> bool:
return post.account.startswith('Expenses:') return post.account.startswith('Expenses:')
def _default_value(self, txn: Transaction, post: Posting) -> MetaValueEnum: def _default_value(self, txn: Transaction, post: data.Posting) -> MetaValueEnum:
return self.DEFAULT_VALUES.get(post.account, 'program') return self.DEFAULT_VALUES.get(post.account, 'program')

View file

@ -17,8 +17,8 @@
import decimal import decimal
from . import core from . import core
from .. import data
from ..beancount_types import ( from ..beancount_types import (
Posting,
Transaction, Transaction,
) )
@ -45,7 +45,7 @@ class MetaTaxImplication(core._NormalizePostingMetadataHook):
'W2', 'W2',
], {}) ], {})
def _run_on_post(self, txn: Transaction, post: Posting) -> bool: def _run_on_post(self, txn: Transaction, post: data.Posting) -> bool:
return bool( return bool(
post.account.startswith('Assets:') post.account.startswith('Assets:')
and post.units.number and post.units.number