Improve organization between modules.
* Rename _typing to beancount_types to better reflect what it is. * LessComparable isn't a Beancount type, so move that to plugin.core with its dependent helper classes. * Errors are a core Beancount concept, so move that module to the top level and have it include appropriate type definitions.
This commit is contained in:
		
							parent
							
								
									a41feb94b3
								
							
						
					
					
						commit
						3fbc14d377
					
				
					 7 changed files with 35 additions and 32 deletions
				
			
		| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
"""Type definitions for Conservancy Beancount code"""
 | 
					"""Type definitions for Beancount data structures"""
 | 
				
			||||||
# Copyright © 2020  Brett Smith
 | 
					# Copyright © 2020  Brett Smith
 | 
				
			||||||
#
 | 
					#
 | 
				
			||||||
# This program is free software: you can redistribute it and/or modify
 | 
					# This program is free software: you can redistribute it and/or modify
 | 
				
			||||||
| 
						 | 
					@ -18,7 +18,6 @@ import abc
 | 
				
			||||||
import datetime
 | 
					import datetime
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import beancount.core.data as bc_data
 | 
					import beancount.core.data as bc_data
 | 
				
			||||||
from .plugin import errors
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
from typing import (
 | 
					from typing import (
 | 
				
			||||||
    Any,
 | 
					    Any,
 | 
				
			||||||
| 
						 | 
					@ -33,21 +32,11 @@ from typing import (
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
Account = bc_data.Account
 | 
					Account = bc_data.Account
 | 
				
			||||||
Error = errors._BaseError
 | 
					 | 
				
			||||||
ErrorIter = Iterable[Error]
 | 
					 | 
				
			||||||
MetaKey = str
 | 
					MetaKey = str
 | 
				
			||||||
MetaValue = Any
 | 
					MetaValue = Any
 | 
				
			||||||
MetaValueEnum = str
 | 
					MetaValueEnum = str
 | 
				
			||||||
Posting = bc_data.Posting
 | 
					Posting = bc_data.Posting
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class LessComparable(metaclass=abc.ABCMeta):
 | 
					 | 
				
			||||||
    @abc.abstractmethod
 | 
					 | 
				
			||||||
    def __le__(self, other: Any) -> bool: ...
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    @abc.abstractmethod
 | 
					 | 
				
			||||||
    def __lt__(self, other: Any) -> bool: ...
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class Directive(NamedTuple):
 | 
					class Directive(NamedTuple):
 | 
				
			||||||
    meta: bc_data.Meta
 | 
					    meta: bc_data.Meta
 | 
				
			||||||
    date: datetime.date
 | 
					    date: datetime.date
 | 
				
			||||||
| 
						 | 
					@ -14,7 +14,11 @@
 | 
				
			||||||
# You should have received a copy of the GNU Affero General Public License
 | 
					# You should have received a copy of the GNU Affero General Public License
 | 
				
			||||||
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
 | 
					# along with this program.  If not, see <https://www.gnu.org/licenses/>.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class _BaseError(Exception):
 | 
					from typing import (
 | 
				
			||||||
 | 
					    Iterable,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class Error(Exception):
 | 
				
			||||||
    def __init__(self, message, entry, source=None):
 | 
					    def __init__(self, message, entry, source=None):
 | 
				
			||||||
        self.message = message
 | 
					        self.message = message
 | 
				
			||||||
        self.entry = entry
 | 
					        self.entry = entry
 | 
				
			||||||
| 
						 | 
					@ -28,7 +32,9 @@ class _BaseError(Exception):
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class InvalidMetadataError(_BaseError):
 | 
					Iter = Iterable[Error]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class InvalidMetadataError(Error):
 | 
				
			||||||
    def __init__(self, txn, post, key, value=None, source=None):
 | 
					    def __init__(self, txn, post, key, value=None, source=None):
 | 
				
			||||||
        if value is None:
 | 
					        if value is None:
 | 
				
			||||||
            msg_fmt = "{post.account} missing {key}"
 | 
					            msg_fmt = "{post.account} missing {key}"
 | 
				
			||||||
| 
						 | 
					@ -28,15 +28,17 @@ from typing import (
 | 
				
			||||||
    Tuple,
 | 
					    Tuple,
 | 
				
			||||||
    Type,
 | 
					    Type,
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
from .._typing import (
 | 
					from ..beancount_types import (
 | 
				
			||||||
    ALL_DIRECTIVES,
 | 
					    ALL_DIRECTIVES,
 | 
				
			||||||
    Directive,
 | 
					    Directive,
 | 
				
			||||||
    Error,
 | 
					 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
from .core import (
 | 
					from .core import (
 | 
				
			||||||
    Hook,
 | 
					    Hook,
 | 
				
			||||||
    HookName,
 | 
					    HookName,
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					from ..errors import (
 | 
				
			||||||
 | 
					    Error,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
__plugins__ = ['run']
 | 
					__plugins__ = ['run']
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -18,9 +18,10 @@ import abc
 | 
				
			||||||
import datetime
 | 
					import datetime
 | 
				
			||||||
import re
 | 
					import re
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from . import errors as errormod
 | 
					from .. import errors as errormod
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from typing import (
 | 
					from typing import (
 | 
				
			||||||
 | 
					    Any,
 | 
				
			||||||
    FrozenSet,
 | 
					    FrozenSet,
 | 
				
			||||||
    Generic,
 | 
					    Generic,
 | 
				
			||||||
    Iterable,
 | 
					    Iterable,
 | 
				
			||||||
| 
						 | 
					@ -29,12 +30,9 @@ from typing import (
 | 
				
			||||||
    Optional,
 | 
					    Optional,
 | 
				
			||||||
    TypeVar,
 | 
					    TypeVar,
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
from .._typing import (
 | 
					from ..beancount_types import (
 | 
				
			||||||
    Account,
 | 
					    Account,
 | 
				
			||||||
    Directive,
 | 
					    Directive,
 | 
				
			||||||
    Error,
 | 
					 | 
				
			||||||
    ErrorIter,
 | 
					 | 
				
			||||||
    LessComparable,
 | 
					 | 
				
			||||||
    MetaKey,
 | 
					    MetaKey,
 | 
				
			||||||
    MetaValue,
 | 
					    MetaValue,
 | 
				
			||||||
    MetaValueEnum,
 | 
					    MetaValueEnum,
 | 
				
			||||||
| 
						 | 
					@ -62,7 +60,7 @@ class Hook(Generic[Entry], metaclass=abc.ABCMeta):
 | 
				
			||||||
    HOOK_GROUPS: FrozenSet[HookName] = frozenset()
 | 
					    HOOK_GROUPS: FrozenSet[HookName] = frozenset()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @abc.abstractmethod
 | 
					    @abc.abstractmethod
 | 
				
			||||||
    def run(self, entry: Entry) -> ErrorIter: ...
 | 
					    def run(self, entry: Entry) -> errormod.Iter: ...
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __init_subclass__(cls):
 | 
					    def __init_subclass__(cls):
 | 
				
			||||||
        cls.DIRECTIVE = cls.__orig_bases__[0].__args__[0]
 | 
					        cls.DIRECTIVE = cls.__orig_bases__[0].__args__[0]
 | 
				
			||||||
| 
						 | 
					@ -72,6 +70,14 @@ TransactionHook = Hook[Transaction]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
### HELPER CLASSES
 | 
					### HELPER CLASSES
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class LessComparable(metaclass=abc.ABCMeta):
 | 
				
			||||||
 | 
					    @abc.abstractmethod
 | 
				
			||||||
 | 
					    def __le__(self, other: Any) -> bool: ...
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @abc.abstractmethod
 | 
				
			||||||
 | 
					    def __lt__(self, other: Any) -> bool: ...
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
CT = TypeVar('CT', bound=LessComparable)
 | 
					CT = TypeVar('CT', bound=LessComparable)
 | 
				
			||||||
class _GenericRange(Generic[CT]):
 | 
					class _GenericRange(Generic[CT]):
 | 
				
			||||||
    """Convenience class to check whether a value is within a range.
 | 
					    """Convenience class to check whether a value is within a range.
 | 
				
			||||||
| 
						 | 
					@ -200,14 +206,14 @@ class _PostingHook(TransactionHook, metaclass=abc.ABCMeta):
 | 
				
			||||||
    def _run_on_post(self, txn: Transaction, post: Posting) -> bool:
 | 
					    def _run_on_post(self, txn: Transaction, post: Posting) -> bool:
 | 
				
			||||||
        return True
 | 
					        return True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def run(self, txn: Transaction) -> ErrorIter:
 | 
					    def run(self, txn: Transaction) -> errormod.Iter:
 | 
				
			||||||
        if self._run_on_txn(txn):
 | 
					        if self._run_on_txn(txn):
 | 
				
			||||||
            for index, post in enumerate(txn.postings):
 | 
					            for index, post in enumerate(txn.postings):
 | 
				
			||||||
                if self._run_on_post(txn, post):
 | 
					                if self._run_on_post(txn, post):
 | 
				
			||||||
                    yield from self.post_run(txn, post, index)
 | 
					                    yield from self.post_run(txn, post, index)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @abc.abstractmethod
 | 
					    @abc.abstractmethod
 | 
				
			||||||
    def post_run(self, txn: Transaction, post: Posting, post_index: int) -> ErrorIter: ...
 | 
					    def post_run(self, txn: Transaction, post: Posting, post_index: int) -> errormod.Iter: ...
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class _NormalizePostingMetadataHook(_PostingHook):
 | 
					class _NormalizePostingMetadataHook(_PostingHook):
 | 
				
			||||||
| 
						 | 
					@ -229,14 +235,14 @@ class _NormalizePostingMetadataHook(_PostingHook):
 | 
				
			||||||
    def _default_value(self, txn: Transaction, post: Posting) -> MetaValueEnum:
 | 
					    def _default_value(self, txn: Transaction, post: Posting) -> MetaValueEnum:
 | 
				
			||||||
        raise errormod.InvalidMetadataError(txn, post, self.METADATA_KEY)
 | 
					        raise errormod.InvalidMetadataError(txn, post, self.METADATA_KEY)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def post_run(self, txn: Transaction, post: Posting, post_index: int) -> ErrorIter:
 | 
					    def post_run(self, txn: Transaction, post: Posting, post_index: int) -> errormod.Iter:
 | 
				
			||||||
        source_value = self._meta_get(txn, post, self.METADATA_KEY)
 | 
					        source_value = self._meta_get(txn, post, self.METADATA_KEY)
 | 
				
			||||||
        set_value = source_value
 | 
					        set_value = source_value
 | 
				
			||||||
        error: Optional[Error] = None
 | 
					        error: Optional[errormod.Error] = None
 | 
				
			||||||
        if source_value is None:
 | 
					        if source_value is None:
 | 
				
			||||||
            try:
 | 
					            try:
 | 
				
			||||||
                set_value = self._default_value(txn, post)
 | 
					                set_value = self._default_value(txn, post)
 | 
				
			||||||
            except errormod._BaseError as error_:
 | 
					            except errormod.Error as error_:
 | 
				
			||||||
                error = error_
 | 
					                error = error_
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            try:
 | 
					            try:
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -15,7 +15,7 @@
 | 
				
			||||||
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
 | 
					# along with this program.  If not, see <https://www.gnu.org/licenses/>.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from . import core
 | 
					from . import core
 | 
				
			||||||
from .._typing import (
 | 
					from ..beancount_types import (
 | 
				
			||||||
    MetaValueEnum,
 | 
					    MetaValueEnum,
 | 
				
			||||||
    Posting,
 | 
					    Posting,
 | 
				
			||||||
    Transaction,
 | 
					    Transaction,
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -17,7 +17,7 @@
 | 
				
			||||||
import decimal
 | 
					import decimal
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from . import core
 | 
					from . import core
 | 
				
			||||||
from .._typing import (
 | 
					from ..beancount_types import (
 | 
				
			||||||
    Posting,
 | 
					    Posting,
 | 
				
			||||||
    Transaction,
 | 
					    Transaction,
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -18,14 +18,14 @@ import pytest
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from . import testutil
 | 
					from . import testutil
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from conservancy_beancount import plugin, _typing
 | 
					from conservancy_beancount import beancount_types, plugin
 | 
				
			||||||
 | 
					
 | 
				
			||||||
CONFIG_MAP = {}
 | 
					CONFIG_MAP = {}
 | 
				
			||||||
HOOK_REGISTRY = plugin.HookRegistry()
 | 
					HOOK_REGISTRY = plugin.HookRegistry()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@HOOK_REGISTRY.add_hook
 | 
					@HOOK_REGISTRY.add_hook
 | 
				
			||||||
class TransactionCounter:
 | 
					class TransactionCounter:
 | 
				
			||||||
    DIRECTIVE = _typing.Transaction
 | 
					    DIRECTIVE = beancount_types.Transaction
 | 
				
			||||||
    HOOK_GROUPS = frozenset()
 | 
					    HOOK_GROUPS = frozenset()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def run(self, txn):
 | 
					    def run(self, txn):
 | 
				
			||||||
| 
						 | 
					@ -34,7 +34,7 @@ class TransactionCounter:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@HOOK_REGISTRY.add_hook
 | 
					@HOOK_REGISTRY.add_hook
 | 
				
			||||||
class PostingCounter(TransactionCounter):
 | 
					class PostingCounter(TransactionCounter):
 | 
				
			||||||
    DIRECTIVE = _typing.Transaction
 | 
					    DIRECTIVE = beancount_types.Transaction
 | 
				
			||||||
    HOOK_GROUPS = frozenset(['posting'])
 | 
					    HOOK_GROUPS = frozenset(['posting'])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def run(self, txn):
 | 
					    def run(self, txn):
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
	Add table
		
		Reference in a new issue