Improve organization between modules.

* Rename _typing to beancount_types to better reflect what it is.
* LessComparable isn't a Beancount type, so move that to
  plugin.core with its dependent helper classes.
* Errors are a core Beancount concept, so move that module to the
  top level and have it include appropriate type definitions.
This commit is contained in:
Brett Smith 2020-03-15 16:01:40 -04:00
parent a41feb94b3
commit 3fbc14d377
7 changed files with 35 additions and 32 deletions

View file

@ -1,4 +1,4 @@
"""Type definitions for Conservancy Beancount code""" """Type definitions for Beancount data structures"""
# Copyright © 2020 Brett Smith # Copyright © 2020 Brett Smith
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
@ -18,7 +18,6 @@ import abc
import datetime import datetime
import beancount.core.data as bc_data import beancount.core.data as bc_data
from .plugin import errors
from typing import ( from typing import (
Any, Any,
@ -33,21 +32,11 @@ from typing import (
) )
Account = bc_data.Account Account = bc_data.Account
Error = errors._BaseError
ErrorIter = Iterable[Error]
MetaKey = str MetaKey = str
MetaValue = Any MetaValue = Any
MetaValueEnum = str MetaValueEnum = str
Posting = bc_data.Posting Posting = bc_data.Posting
class LessComparable(metaclass=abc.ABCMeta):
@abc.abstractmethod
def __le__(self, other: Any) -> bool: ...
@abc.abstractmethod
def __lt__(self, other: Any) -> bool: ...
class Directive(NamedTuple): class Directive(NamedTuple):
meta: bc_data.Meta meta: bc_data.Meta
date: datetime.date date: datetime.date

View file

@ -14,7 +14,11 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
class _BaseError(Exception): from typing import (
Iterable,
)
class Error(Exception):
def __init__(self, message, entry, source=None): def __init__(self, message, entry, source=None):
self.message = message self.message = message
self.entry = entry self.entry = entry
@ -28,7 +32,9 @@ class _BaseError(Exception):
) )
class InvalidMetadataError(_BaseError): Iter = Iterable[Error]
class InvalidMetadataError(Error):
def __init__(self, txn, post, key, value=None, source=None): def __init__(self, txn, post, key, value=None, source=None):
if value is None: if value is None:
msg_fmt = "{post.account} missing {key}" msg_fmt = "{post.account} missing {key}"

View file

@ -28,15 +28,17 @@ from typing import (
Tuple, Tuple,
Type, Type,
) )
from .._typing import ( from ..beancount_types import (
ALL_DIRECTIVES, ALL_DIRECTIVES,
Directive, Directive,
Error,
) )
from .core import ( from .core import (
Hook, Hook,
HookName, HookName,
) )
from ..errors import (
Error,
)
__plugins__ = ['run'] __plugins__ = ['run']

View file

@ -18,9 +18,10 @@ import abc
import datetime import datetime
import re import re
from . import errors as errormod from .. import errors as errormod
from typing import ( from typing import (
Any,
FrozenSet, FrozenSet,
Generic, Generic,
Iterable, Iterable,
@ -29,12 +30,9 @@ from typing import (
Optional, Optional,
TypeVar, TypeVar,
) )
from .._typing import ( from ..beancount_types import (
Account, Account,
Directive, Directive,
Error,
ErrorIter,
LessComparable,
MetaKey, MetaKey,
MetaValue, MetaValue,
MetaValueEnum, MetaValueEnum,
@ -62,7 +60,7 @@ class Hook(Generic[Entry], metaclass=abc.ABCMeta):
HOOK_GROUPS: FrozenSet[HookName] = frozenset() HOOK_GROUPS: FrozenSet[HookName] = frozenset()
@abc.abstractmethod @abc.abstractmethod
def run(self, entry: Entry) -> ErrorIter: ... def run(self, entry: Entry) -> errormod.Iter: ...
def __init_subclass__(cls): def __init_subclass__(cls):
cls.DIRECTIVE = cls.__orig_bases__[0].__args__[0] cls.DIRECTIVE = cls.__orig_bases__[0].__args__[0]
@ -72,6 +70,14 @@ TransactionHook = Hook[Transaction]
### HELPER CLASSES ### HELPER CLASSES
class LessComparable(metaclass=abc.ABCMeta):
@abc.abstractmethod
def __le__(self, other: Any) -> bool: ...
@abc.abstractmethod
def __lt__(self, other: Any) -> bool: ...
CT = TypeVar('CT', bound=LessComparable) CT = TypeVar('CT', bound=LessComparable)
class _GenericRange(Generic[CT]): class _GenericRange(Generic[CT]):
"""Convenience class to check whether a value is within a range. """Convenience class to check whether a value is within a range.
@ -200,14 +206,14 @@ class _PostingHook(TransactionHook, metaclass=abc.ABCMeta):
def _run_on_post(self, txn: Transaction, post: Posting) -> bool: def _run_on_post(self, txn: Transaction, post: Posting) -> bool:
return True return True
def run(self, txn: Transaction) -> ErrorIter: def run(self, txn: Transaction) -> errormod.Iter:
if self._run_on_txn(txn): if self._run_on_txn(txn):
for index, post in enumerate(txn.postings): for index, post in enumerate(txn.postings):
if self._run_on_post(txn, post): if self._run_on_post(txn, post):
yield from self.post_run(txn, post, index) yield from self.post_run(txn, post, index)
@abc.abstractmethod @abc.abstractmethod
def post_run(self, txn: Transaction, post: Posting, post_index: int) -> ErrorIter: ... def post_run(self, txn: Transaction, post: Posting, post_index: int) -> errormod.Iter: ...
class _NormalizePostingMetadataHook(_PostingHook): class _NormalizePostingMetadataHook(_PostingHook):
@ -229,14 +235,14 @@ class _NormalizePostingMetadataHook(_PostingHook):
def _default_value(self, txn: Transaction, post: Posting) -> MetaValueEnum: def _default_value(self, txn: Transaction, post: Posting) -> MetaValueEnum:
raise errormod.InvalidMetadataError(txn, post, self.METADATA_KEY) raise errormod.InvalidMetadataError(txn, post, self.METADATA_KEY)
def post_run(self, txn: Transaction, post: Posting, post_index: int) -> ErrorIter: def post_run(self, txn: Transaction, post: Posting, post_index: int) -> errormod.Iter:
source_value = self._meta_get(txn, post, self.METADATA_KEY) source_value = self._meta_get(txn, post, self.METADATA_KEY)
set_value = source_value set_value = source_value
error: Optional[Error] = None error: Optional[errormod.Error] = None
if source_value is None: if source_value is None:
try: try:
set_value = self._default_value(txn, post) set_value = self._default_value(txn, post)
except errormod._BaseError as error_: except errormod.Error as error_:
error = error_ error = error_
else: else:
try: try:

View file

@ -15,7 +15,7 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from . import core from . import core
from .._typing import ( from ..beancount_types import (
MetaValueEnum, MetaValueEnum,
Posting, Posting,
Transaction, Transaction,

View file

@ -17,7 +17,7 @@
import decimal import decimal
from . import core from . import core
from .._typing import ( from ..beancount_types import (
Posting, Posting,
Transaction, Transaction,
) )

View file

@ -18,14 +18,14 @@ import pytest
from . import testutil from . import testutil
from conservancy_beancount import plugin, _typing from conservancy_beancount import beancount_types, plugin
CONFIG_MAP = {} CONFIG_MAP = {}
HOOK_REGISTRY = plugin.HookRegistry() HOOK_REGISTRY = plugin.HookRegistry()
@HOOK_REGISTRY.add_hook @HOOK_REGISTRY.add_hook
class TransactionCounter: class TransactionCounter:
DIRECTIVE = _typing.Transaction DIRECTIVE = beancount_types.Transaction
HOOK_GROUPS = frozenset() HOOK_GROUPS = frozenset()
def run(self, txn): def run(self, txn):
@ -34,7 +34,7 @@ class TransactionCounter:
@HOOK_REGISTRY.add_hook @HOOK_REGISTRY.add_hook
class PostingCounter(TransactionCounter): class PostingCounter(TransactionCounter):
DIRECTIVE = _typing.Transaction DIRECTIVE = beancount_types.Transaction
HOOK_GROUPS = frozenset(['posting']) HOOK_GROUPS = frozenset(['posting'])
def run(self, txn): def run(self, txn):