Improve organization between modules.
* Rename _typing to beancount_types to better reflect what it is. * LessComparable isn't a Beancount type, so move that to plugin.core with its dependent helper classes. * Errors are a core Beancount concept, so move that module to the top level and have it include appropriate type definitions.
This commit is contained in:
parent
a41feb94b3
commit
3fbc14d377
7 changed files with 35 additions and 32 deletions
|
@ -1,4 +1,4 @@
|
|||
"""Type definitions for Conservancy Beancount code"""
|
||||
"""Type definitions for Beancount data structures"""
|
||||
# Copyright © 2020 Brett Smith
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
|
@ -18,7 +18,6 @@ import abc
|
|||
import datetime
|
||||
|
||||
import beancount.core.data as bc_data
|
||||
from .plugin import errors
|
||||
|
||||
from typing import (
|
||||
Any,
|
||||
|
@ -33,21 +32,11 @@ from typing import (
|
|||
)
|
||||
|
||||
Account = bc_data.Account
|
||||
Error = errors._BaseError
|
||||
ErrorIter = Iterable[Error]
|
||||
MetaKey = str
|
||||
MetaValue = Any
|
||||
MetaValueEnum = str
|
||||
Posting = bc_data.Posting
|
||||
|
||||
class LessComparable(metaclass=abc.ABCMeta):
|
||||
@abc.abstractmethod
|
||||
def __le__(self, other: Any) -> bool: ...
|
||||
|
||||
@abc.abstractmethod
|
||||
def __lt__(self, other: Any) -> bool: ...
|
||||
|
||||
|
||||
class Directive(NamedTuple):
|
||||
meta: bc_data.Meta
|
||||
date: datetime.date
|
|
@ -14,7 +14,11 @@
|
|||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
class _BaseError(Exception):
|
||||
from typing import (
|
||||
Iterable,
|
||||
)
|
||||
|
||||
class Error(Exception):
|
||||
def __init__(self, message, entry, source=None):
|
||||
self.message = message
|
||||
self.entry = entry
|
||||
|
@ -28,7 +32,9 @@ class _BaseError(Exception):
|
|||
)
|
||||
|
||||
|
||||
class InvalidMetadataError(_BaseError):
|
||||
Iter = Iterable[Error]
|
||||
|
||||
class InvalidMetadataError(Error):
|
||||
def __init__(self, txn, post, key, value=None, source=None):
|
||||
if value is None:
|
||||
msg_fmt = "{post.account} missing {key}"
|
|
@ -28,15 +28,17 @@ from typing import (
|
|||
Tuple,
|
||||
Type,
|
||||
)
|
||||
from .._typing import (
|
||||
from ..beancount_types import (
|
||||
ALL_DIRECTIVES,
|
||||
Directive,
|
||||
Error,
|
||||
)
|
||||
from .core import (
|
||||
Hook,
|
||||
HookName,
|
||||
)
|
||||
from ..errors import (
|
||||
Error,
|
||||
)
|
||||
|
||||
__plugins__ = ['run']
|
||||
|
||||
|
|
|
@ -18,9 +18,10 @@ import abc
|
|||
import datetime
|
||||
import re
|
||||
|
||||
from . import errors as errormod
|
||||
from .. import errors as errormod
|
||||
|
||||
from typing import (
|
||||
Any,
|
||||
FrozenSet,
|
||||
Generic,
|
||||
Iterable,
|
||||
|
@ -29,12 +30,9 @@ from typing import (
|
|||
Optional,
|
||||
TypeVar,
|
||||
)
|
||||
from .._typing import (
|
||||
from ..beancount_types import (
|
||||
Account,
|
||||
Directive,
|
||||
Error,
|
||||
ErrorIter,
|
||||
LessComparable,
|
||||
MetaKey,
|
||||
MetaValue,
|
||||
MetaValueEnum,
|
||||
|
@ -62,7 +60,7 @@ class Hook(Generic[Entry], metaclass=abc.ABCMeta):
|
|||
HOOK_GROUPS: FrozenSet[HookName] = frozenset()
|
||||
|
||||
@abc.abstractmethod
|
||||
def run(self, entry: Entry) -> ErrorIter: ...
|
||||
def run(self, entry: Entry) -> errormod.Iter: ...
|
||||
|
||||
def __init_subclass__(cls):
|
||||
cls.DIRECTIVE = cls.__orig_bases__[0].__args__[0]
|
||||
|
@ -72,6 +70,14 @@ TransactionHook = Hook[Transaction]
|
|||
|
||||
### HELPER CLASSES
|
||||
|
||||
class LessComparable(metaclass=abc.ABCMeta):
|
||||
@abc.abstractmethod
|
||||
def __le__(self, other: Any) -> bool: ...
|
||||
|
||||
@abc.abstractmethod
|
||||
def __lt__(self, other: Any) -> bool: ...
|
||||
|
||||
|
||||
CT = TypeVar('CT', bound=LessComparable)
|
||||
class _GenericRange(Generic[CT]):
|
||||
"""Convenience class to check whether a value is within a range.
|
||||
|
@ -200,14 +206,14 @@ class _PostingHook(TransactionHook, metaclass=abc.ABCMeta):
|
|||
def _run_on_post(self, txn: Transaction, post: Posting) -> bool:
|
||||
return True
|
||||
|
||||
def run(self, txn: Transaction) -> ErrorIter:
|
||||
def run(self, txn: Transaction) -> errormod.Iter:
|
||||
if self._run_on_txn(txn):
|
||||
for index, post in enumerate(txn.postings):
|
||||
if self._run_on_post(txn, post):
|
||||
yield from self.post_run(txn, post, index)
|
||||
|
||||
@abc.abstractmethod
|
||||
def post_run(self, txn: Transaction, post: Posting, post_index: int) -> ErrorIter: ...
|
||||
def post_run(self, txn: Transaction, post: Posting, post_index: int) -> errormod.Iter: ...
|
||||
|
||||
|
||||
class _NormalizePostingMetadataHook(_PostingHook):
|
||||
|
@ -229,14 +235,14 @@ class _NormalizePostingMetadataHook(_PostingHook):
|
|||
def _default_value(self, txn: Transaction, post: Posting) -> MetaValueEnum:
|
||||
raise errormod.InvalidMetadataError(txn, post, self.METADATA_KEY)
|
||||
|
||||
def post_run(self, txn: Transaction, post: Posting, post_index: int) -> ErrorIter:
|
||||
def post_run(self, txn: Transaction, post: Posting, post_index: int) -> errormod.Iter:
|
||||
source_value = self._meta_get(txn, post, self.METADATA_KEY)
|
||||
set_value = source_value
|
||||
error: Optional[Error] = None
|
||||
error: Optional[errormod.Error] = None
|
||||
if source_value is None:
|
||||
try:
|
||||
set_value = self._default_value(txn, post)
|
||||
except errormod._BaseError as error_:
|
||||
except errormod.Error as error_:
|
||||
error = error_
|
||||
else:
|
||||
try:
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
from . import core
|
||||
from .._typing import (
|
||||
from ..beancount_types import (
|
||||
MetaValueEnum,
|
||||
Posting,
|
||||
Transaction,
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
import decimal
|
||||
|
||||
from . import core
|
||||
from .._typing import (
|
||||
from ..beancount_types import (
|
||||
Posting,
|
||||
Transaction,
|
||||
)
|
||||
|
|
|
@ -18,14 +18,14 @@ import pytest
|
|||
|
||||
from . import testutil
|
||||
|
||||
from conservancy_beancount import plugin, _typing
|
||||
from conservancy_beancount import beancount_types, plugin
|
||||
|
||||
CONFIG_MAP = {}
|
||||
HOOK_REGISTRY = plugin.HookRegistry()
|
||||
|
||||
@HOOK_REGISTRY.add_hook
|
||||
class TransactionCounter:
|
||||
DIRECTIVE = _typing.Transaction
|
||||
DIRECTIVE = beancount_types.Transaction
|
||||
HOOK_GROUPS = frozenset()
|
||||
|
||||
def run(self, txn):
|
||||
|
@ -34,7 +34,7 @@ class TransactionCounter:
|
|||
|
||||
@HOOK_REGISTRY.add_hook
|
||||
class PostingCounter(TransactionCounter):
|
||||
DIRECTIVE = _typing.Transaction
|
||||
DIRECTIVE = beancount_types.Transaction
|
||||
HOOK_GROUPS = frozenset(['posting'])
|
||||
|
||||
def run(self, txn):
|
||||
|
|
Loading…
Reference in a new issue