accrual: Move more functionality into AccrualPostings.

This commit is contained in:
Brett Smith 2020-06-02 14:11:01 -04:00
parent 677c99b565
commit 58b02b6f33

View file

@ -115,12 +115,12 @@ class Sentinel:
class Account(NamedTuple): class Account(NamedTuple):
name: str name: str
balance_paid: Callable[[core.Balance], bool] norm_func: Callable[[core.Balance], core.Balance]
class AccrualAccount(enum.Enum): class AccrualAccount(enum.Enum):
PAYABLE = Account('Liabilities:Payable', core.Balance.ge_zero) PAYABLE = Account('Liabilities:Payable', operator.neg)
RECEIVABLE = Account('Assets:Receivable', core.Balance.le_zero) RECEIVABLE = Account('Assets:Receivable', lambda bal: bal)
@classmethod @classmethod
def account_names(cls) -> Iterator[str]: def account_names(cls) -> Iterator[str]:
@ -134,14 +134,6 @@ class AccrualAccount(enum.Enum):
return account return account
raise ValueError("unrecognized account set in related postings") raise ValueError("unrecognized account set in related postings")
@classmethod
def filter_paid_accruals(cls, groups: PostGroups) -> PostGroups:
return {
key: related
for key, related in groups.items()
if not cls.classify(related).value.balance_paid(related.balance())
}
class AccrualPostings(core.RelatedPostings): class AccrualPostings(core.RelatedPostings):
def _meta_getter(key: MetaKey) -> Callable[[data.Posting], MetaValue]: # type:ignore[misc] def _meta_getter(key: MetaKey) -> Callable[[data.Posting], MetaValue]: # type:ignore[misc]
@ -160,6 +152,7 @@ class AccrualPostings(core.RelatedPostings):
INCONSISTENT = Sentinel() INCONSISTENT = Sentinel()
__slots__ = ( __slots__ = (
'accrual_type', 'accrual_type',
'final_bal',
'account', 'account',
'accounts', 'accounts',
'contract', 'contract',
@ -198,8 +191,10 @@ class AccrualPostings(core.RelatedPostings):
self.entities = self.entitys self.entities = self.entitys
if self.account is self.INCONSISTENT: if self.account is self.INCONSISTENT:
self.accrual_type: Optional[AccrualAccount] = None self.accrual_type: Optional[AccrualAccount] = None
self.final_bal = self.balance()
else: else:
self.accrual_type = AccrualAccount.classify(self) self.accrual_type = AccrualAccount.classify(self)
self.final_bal = self.accrual_type.value.norm_func(self.balance())
def make_consistent(self) -> Iterator[Tuple[MetaValue, 'AccrualPostings']]: def make_consistent(self) -> Iterator[Tuple[MetaValue, 'AccrualPostings']]:
account_ok = isinstance(self.account, str) account_ok = isinstance(self.account, str)
@ -232,21 +227,33 @@ class AccrualPostings(core.RelatedPostings):
) )
yield Error(post.meta, errmsg, post.meta.txn) yield Error(post.meta, errmsg, post.meta.txn)
def is_paid(self, default: Optional[bool]=None) -> Optional[bool]:
if self.accrual_type is None:
return default
else:
return self.final_bal.le_zero()
class BaseReport: def is_zero(self, default: Optional[bool]=None) -> Optional[bool]:
def __init__(self, out_file: TextIO) -> None: if self.accrual_type is None:
self.out_file = out_file return default
self.logger = logger.getChild(type(self).__name__) else:
return self.final_bal.is_zero()
def _since_last_nonzero(self, posts: AccrualPostings) -> AccrualPostings: def since_last_nonzero(self) -> 'AccrualPostings':
for index, (post, balance) in enumerate(posts.iter_with_balance()): for index, (post, balance) in enumerate(self.iter_with_balance()):
if balance.is_zero(): if balance.is_zero():
start_index = index start_index = index
try: try:
empty = start_index == index empty = start_index == index
except NameError: except NameError:
empty = True empty = True
return posts if empty else AccrualPostings(posts[start_index + 1:]) return self if empty else self[start_index + 1:]
class BaseReport:
def __init__(self, out_file: TextIO) -> None:
self.out_file = out_file
self.logger = logger.getChild(type(self).__name__)
def _report(self, def _report(self,
invoice: str, invoice: str,
@ -267,7 +274,7 @@ class BalanceReport(BaseReport):
posts: AccrualPostings, posts: AccrualPostings,
index: int, index: int,
) -> Iterable[str]: ) -> Iterable[str]:
posts = self._since_last_nonzero(posts) posts = posts.since_last_nonzero()
balance = posts.balance() balance = posts.balance()
date_s = posts[0].meta.date.strftime('%Y-%m-%d') date_s = posts[0].meta.date.strftime('%Y-%m-%d')
if index: if index:
@ -298,7 +305,7 @@ class OutgoingReport(BaseReport):
posts: AccrualPostings, posts: AccrualPostings,
index: int, index: int,
) -> Iterable[str]: ) -> Iterable[str]:
posts = self._since_last_nonzero(posts) posts = posts.since_last_nonzero()
try: try:
ticket_id, _ = self._primary_rt_id(posts) ticket_id, _ = self._primary_rt_id(posts)
ticket = self.rt_client.get_ticket(ticket_id) ticket = self.rt_client.get_ticket(ticket_id)
@ -329,13 +336,12 @@ class OutgoingReport(BaseReport):
) )
requestor = f'{requestor_name} <{rt_requestor["EmailAddress"]}>'.strip() requestor = f'{requestor_name} <{rt_requestor["EmailAddress"]}>'.strip()
raw_balance = -posts.balance()
cost_balance = -posts.balance_at_cost() cost_balance = -posts.balance_at_cost()
cost_balance_s = cost_balance.format(None) cost_balance_s = cost_balance.format(None)
if raw_balance == cost_balance: if posts.final_bal == cost_balance:
balance_s = cost_balance_s balance_s = cost_balance_s
else: else:
balance_s = f'{raw_balance} ({cost_balance_s})' balance_s = f'{posts.final_bal} ({cost_balance_s})'
contract_links = posts.all_meta_links('contract') contract_links = posts.all_meta_links('contract')
if contract_links: if contract_links:
@ -382,8 +388,8 @@ class ReportType(enum.Enum):
@classmethod @classmethod
def default_for(cls, groups: PostGroups) -> 'ReportType': def default_for(cls, groups: PostGroups) -> 'ReportType':
if len(groups) == 1 and all( if len(groups) == 1 and all(
AccrualAccount.classify(group) is AccrualAccount.PAYABLE group.accrual_type is AccrualAccount.PAYABLE
and not AccrualAccount.PAYABLE.value.balance_paid(group.balance()) and not group.is_paid()
for group in groups.values() for group in groups.values()
): ):
return cls.OUTGOING return cls.OUTGOING
@ -501,7 +507,7 @@ def main(arglist: Optional[Sequence[str]]=None,
filters.remove_opening_balance_txn(entries) filters.remove_opening_balance_txn(entries)
postings = filter_search(data.Posting.from_entries(entries), args.search_terms) postings = filter_search(data.Posting.from_entries(entries), args.search_terms)
groups: PostGroups = dict(AccrualPostings.group_by_meta(postings, 'invoice')) groups: PostGroups = dict(AccrualPostings.group_by_meta(postings, 'invoice'))
groups = AccrualAccount.filter_paid_accruals(groups) or groups groups = {key: group for key, group in groups.items() if not group.is_paid()} or groups
returncode = 0 returncode = 0
for error in load_errors: for error in load_errors:
bc_printer.print_error(error, file=stderr) bc_printer.print_error(error, file=stderr)