Improve organization between modules.
* Rename _typing to beancount_types to better reflect what it is. * LessComparable isn't a Beancount type, so move that to plugin.core with its dependent helper classes. * Errors are a core Beancount concept, so move that module to the top level and have it include appropriate type definitions.
This commit is contained in:
parent
a41feb94b3
commit
3fbc14d377
7 changed files with 35 additions and 32 deletions
|
@ -1,4 +1,4 @@
|
||||||
"""Type definitions for Conservancy Beancount code"""
|
"""Type definitions for Beancount data structures"""
|
||||||
# Copyright © 2020 Brett Smith
|
# Copyright © 2020 Brett Smith
|
||||||
#
|
#
|
||||||
# This program is free software: you can redistribute it and/or modify
|
# This program is free software: you can redistribute it and/or modify
|
||||||
|
@ -18,7 +18,6 @@ import abc
|
||||||
import datetime
|
import datetime
|
||||||
|
|
||||||
import beancount.core.data as bc_data
|
import beancount.core.data as bc_data
|
||||||
from .plugin import errors
|
|
||||||
|
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
|
@ -33,21 +32,11 @@ from typing import (
|
||||||
)
|
)
|
||||||
|
|
||||||
Account = bc_data.Account
|
Account = bc_data.Account
|
||||||
Error = errors._BaseError
|
|
||||||
ErrorIter = Iterable[Error]
|
|
||||||
MetaKey = str
|
MetaKey = str
|
||||||
MetaValue = Any
|
MetaValue = Any
|
||||||
MetaValueEnum = str
|
MetaValueEnum = str
|
||||||
Posting = bc_data.Posting
|
Posting = bc_data.Posting
|
||||||
|
|
||||||
class LessComparable(metaclass=abc.ABCMeta):
|
|
||||||
@abc.abstractmethod
|
|
||||||
def __le__(self, other: Any) -> bool: ...
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def __lt__(self, other: Any) -> bool: ...
|
|
||||||
|
|
||||||
|
|
||||||
class Directive(NamedTuple):
|
class Directive(NamedTuple):
|
||||||
meta: bc_data.Meta
|
meta: bc_data.Meta
|
||||||
date: datetime.date
|
date: datetime.date
|
|
@ -14,7 +14,11 @@
|
||||||
# You should have received a copy of the GNU Affero General Public License
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
class _BaseError(Exception):
|
from typing import (
|
||||||
|
Iterable,
|
||||||
|
)
|
||||||
|
|
||||||
|
class Error(Exception):
|
||||||
def __init__(self, message, entry, source=None):
|
def __init__(self, message, entry, source=None):
|
||||||
self.message = message
|
self.message = message
|
||||||
self.entry = entry
|
self.entry = entry
|
||||||
|
@ -28,7 +32,9 @@ class _BaseError(Exception):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class InvalidMetadataError(_BaseError):
|
Iter = Iterable[Error]
|
||||||
|
|
||||||
|
class InvalidMetadataError(Error):
|
||||||
def __init__(self, txn, post, key, value=None, source=None):
|
def __init__(self, txn, post, key, value=None, source=None):
|
||||||
if value is None:
|
if value is None:
|
||||||
msg_fmt = "{post.account} missing {key}"
|
msg_fmt = "{post.account} missing {key}"
|
|
@ -28,15 +28,17 @@ from typing import (
|
||||||
Tuple,
|
Tuple,
|
||||||
Type,
|
Type,
|
||||||
)
|
)
|
||||||
from .._typing import (
|
from ..beancount_types import (
|
||||||
ALL_DIRECTIVES,
|
ALL_DIRECTIVES,
|
||||||
Directive,
|
Directive,
|
||||||
Error,
|
|
||||||
)
|
)
|
||||||
from .core import (
|
from .core import (
|
||||||
Hook,
|
Hook,
|
||||||
HookName,
|
HookName,
|
||||||
)
|
)
|
||||||
|
from ..errors import (
|
||||||
|
Error,
|
||||||
|
)
|
||||||
|
|
||||||
__plugins__ = ['run']
|
__plugins__ = ['run']
|
||||||
|
|
||||||
|
|
|
@ -18,9 +18,10 @@ import abc
|
||||||
import datetime
|
import datetime
|
||||||
import re
|
import re
|
||||||
|
|
||||||
from . import errors as errormod
|
from .. import errors as errormod
|
||||||
|
|
||||||
from typing import (
|
from typing import (
|
||||||
|
Any,
|
||||||
FrozenSet,
|
FrozenSet,
|
||||||
Generic,
|
Generic,
|
||||||
Iterable,
|
Iterable,
|
||||||
|
@ -29,12 +30,9 @@ from typing import (
|
||||||
Optional,
|
Optional,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
)
|
)
|
||||||
from .._typing import (
|
from ..beancount_types import (
|
||||||
Account,
|
Account,
|
||||||
Directive,
|
Directive,
|
||||||
Error,
|
|
||||||
ErrorIter,
|
|
||||||
LessComparable,
|
|
||||||
MetaKey,
|
MetaKey,
|
||||||
MetaValue,
|
MetaValue,
|
||||||
MetaValueEnum,
|
MetaValueEnum,
|
||||||
|
@ -62,7 +60,7 @@ class Hook(Generic[Entry], metaclass=abc.ABCMeta):
|
||||||
HOOK_GROUPS: FrozenSet[HookName] = frozenset()
|
HOOK_GROUPS: FrozenSet[HookName] = frozenset()
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def run(self, entry: Entry) -> ErrorIter: ...
|
def run(self, entry: Entry) -> errormod.Iter: ...
|
||||||
|
|
||||||
def __init_subclass__(cls):
|
def __init_subclass__(cls):
|
||||||
cls.DIRECTIVE = cls.__orig_bases__[0].__args__[0]
|
cls.DIRECTIVE = cls.__orig_bases__[0].__args__[0]
|
||||||
|
@ -72,6 +70,14 @@ TransactionHook = Hook[Transaction]
|
||||||
|
|
||||||
### HELPER CLASSES
|
### HELPER CLASSES
|
||||||
|
|
||||||
|
class LessComparable(metaclass=abc.ABCMeta):
|
||||||
|
@abc.abstractmethod
|
||||||
|
def __le__(self, other: Any) -> bool: ...
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def __lt__(self, other: Any) -> bool: ...
|
||||||
|
|
||||||
|
|
||||||
CT = TypeVar('CT', bound=LessComparable)
|
CT = TypeVar('CT', bound=LessComparable)
|
||||||
class _GenericRange(Generic[CT]):
|
class _GenericRange(Generic[CT]):
|
||||||
"""Convenience class to check whether a value is within a range.
|
"""Convenience class to check whether a value is within a range.
|
||||||
|
@ -200,14 +206,14 @@ class _PostingHook(TransactionHook, metaclass=abc.ABCMeta):
|
||||||
def _run_on_post(self, txn: Transaction, post: Posting) -> bool:
|
def _run_on_post(self, txn: Transaction, post: Posting) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def run(self, txn: Transaction) -> ErrorIter:
|
def run(self, txn: Transaction) -> errormod.Iter:
|
||||||
if self._run_on_txn(txn):
|
if self._run_on_txn(txn):
|
||||||
for index, post in enumerate(txn.postings):
|
for index, post in enumerate(txn.postings):
|
||||||
if self._run_on_post(txn, post):
|
if self._run_on_post(txn, post):
|
||||||
yield from self.post_run(txn, post, index)
|
yield from self.post_run(txn, post, index)
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def post_run(self, txn: Transaction, post: Posting, post_index: int) -> ErrorIter: ...
|
def post_run(self, txn: Transaction, post: Posting, post_index: int) -> errormod.Iter: ...
|
||||||
|
|
||||||
|
|
||||||
class _NormalizePostingMetadataHook(_PostingHook):
|
class _NormalizePostingMetadataHook(_PostingHook):
|
||||||
|
@ -229,14 +235,14 @@ class _NormalizePostingMetadataHook(_PostingHook):
|
||||||
def _default_value(self, txn: Transaction, post: Posting) -> MetaValueEnum:
|
def _default_value(self, txn: Transaction, post: Posting) -> MetaValueEnum:
|
||||||
raise errormod.InvalidMetadataError(txn, post, self.METADATA_KEY)
|
raise errormod.InvalidMetadataError(txn, post, self.METADATA_KEY)
|
||||||
|
|
||||||
def post_run(self, txn: Transaction, post: Posting, post_index: int) -> ErrorIter:
|
def post_run(self, txn: Transaction, post: Posting, post_index: int) -> errormod.Iter:
|
||||||
source_value = self._meta_get(txn, post, self.METADATA_KEY)
|
source_value = self._meta_get(txn, post, self.METADATA_KEY)
|
||||||
set_value = source_value
|
set_value = source_value
|
||||||
error: Optional[Error] = None
|
error: Optional[errormod.Error] = None
|
||||||
if source_value is None:
|
if source_value is None:
|
||||||
try:
|
try:
|
||||||
set_value = self._default_value(txn, post)
|
set_value = self._default_value(txn, post)
|
||||||
except errormod._BaseError as error_:
|
except errormod.Error as error_:
|
||||||
error = error_
|
error = error_
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
from . import core
|
from . import core
|
||||||
from .._typing import (
|
from ..beancount_types import (
|
||||||
MetaValueEnum,
|
MetaValueEnum,
|
||||||
Posting,
|
Posting,
|
||||||
Transaction,
|
Transaction,
|
||||||
|
|
|
@ -17,7 +17,7 @@
|
||||||
import decimal
|
import decimal
|
||||||
|
|
||||||
from . import core
|
from . import core
|
||||||
from .._typing import (
|
from ..beancount_types import (
|
||||||
Posting,
|
Posting,
|
||||||
Transaction,
|
Transaction,
|
||||||
)
|
)
|
||||||
|
|
|
@ -18,14 +18,14 @@ import pytest
|
||||||
|
|
||||||
from . import testutil
|
from . import testutil
|
||||||
|
|
||||||
from conservancy_beancount import plugin, _typing
|
from conservancy_beancount import beancount_types, plugin
|
||||||
|
|
||||||
CONFIG_MAP = {}
|
CONFIG_MAP = {}
|
||||||
HOOK_REGISTRY = plugin.HookRegistry()
|
HOOK_REGISTRY = plugin.HookRegistry()
|
||||||
|
|
||||||
@HOOK_REGISTRY.add_hook
|
@HOOK_REGISTRY.add_hook
|
||||||
class TransactionCounter:
|
class TransactionCounter:
|
||||||
DIRECTIVE = _typing.Transaction
|
DIRECTIVE = beancount_types.Transaction
|
||||||
HOOK_GROUPS = frozenset()
|
HOOK_GROUPS = frozenset()
|
||||||
|
|
||||||
def run(self, txn):
|
def run(self, txn):
|
||||||
|
@ -34,7 +34,7 @@ class TransactionCounter:
|
||||||
|
|
||||||
@HOOK_REGISTRY.add_hook
|
@HOOK_REGISTRY.add_hook
|
||||||
class PostingCounter(TransactionCounter):
|
class PostingCounter(TransactionCounter):
|
||||||
DIRECTIVE = _typing.Transaction
|
DIRECTIVE = beancount_types.Transaction
|
||||||
HOOK_GROUPS = frozenset(['posting'])
|
HOOK_GROUPS = frozenset(['posting'])
|
||||||
|
|
||||||
def run(self, txn):
|
def run(self, txn):
|
||||||
|
|
Loading…
Reference in a new issue