plugin: Refactor hooks to use new payment-related methods.

This commit is contained in:
Brett Smith 2020-03-29 10:30:54 -04:00
parent 5f85d9c747
commit 30d371278a
3 changed files with 17 additions and 11 deletions

View file

@ -15,6 +15,7 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from . import core
from .. import config as configmod
from .. import data
from .. import errors as errormod
from ..beancount_types import (
@ -23,14 +24,16 @@ from ..beancount_types import (
)
class MetaReceipt(core._RequireLinksPostingMetadataHook):
DEFAULT_STOP_AMOUNT = 0
METADATA_KEY = 'receipt'
def __init__(self, config: configmod.Config) -> None:
self.payment_threshold = -abs(config.payment_threshold())
def _run_on_post(self, txn: Transaction, post: data.Posting) -> bool:
return bool(
(post.account.is_real_asset() or post.account.is_under('Liabilities'))
and post.units.number
and post.units.number < self.DEFAULT_STOP_AMOUNT
and post.units.number is not None
and post.units.number < self.payment_threshold
)
def post_run(self, txn: Transaction, post: data.Posting) -> errormod.Iter:

View file

@ -17,13 +17,12 @@
import decimal
from . import core
from .. import config as configmod
from .. import data
from ..beancount_types import (
Transaction,
)
DEFAULT_STOP_AMOUNT = decimal.Decimal(0)
class MetaTaxImplication(core._NormalizePostingMetadataHook):
VALUES_ENUM = core.MetadataEnum('tax-implication', [
'1099',
@ -45,9 +44,8 @@ class MetaTaxImplication(core._NormalizePostingMetadataHook):
'W2',
])
def __init__(self, config: configmod.Config) -> None:
self.payment_threshold = config.payment_threshold()
def _run_on_post(self, txn: Transaction, post: data.Posting) -> bool:
return bool(
post.account.is_real_asset()
and post.units.number
and post.units.number < DEFAULT_STOP_AMOUNT
)
return post.is_payment(self.payment_threshold)

View file

@ -144,10 +144,12 @@ class Transaction:
class TestConfig:
def __init__(self,
def __init__(self, *,
payment_threshold=0,
repo_path=None,
rt_client=None,
):
self._payment_threshold = Decimal(payment_threshold)
self.repo_path = test_path(repo_path)
self._rt_client = rt_client
if rt_client is None:
@ -155,6 +157,9 @@ class TestConfig:
else:
self._rt_wrapper = rtutil.RT(rt_client)
def payment_threshold(self):
return self._payment_threshold
def repository_path(self):
return self.repo_path