data: Add part slicing methods to Account.

This commit is contained in:
Brett Smith 2020-06-06 16:38:53 -04:00
parent 2b5cb0eca6
commit 8d3d7e7ce4
3 changed files with 150 additions and 1 deletions

View file

@ -31,6 +31,7 @@ from beancount.core import position as bc_position
from typing import (
cast,
overload,
Callable,
Hashable,
Iterable,
@ -121,6 +122,64 @@ class Account(str):
return prefix
return None
def _find_part_slice(self, index: int) -> slice:
if index < 0:
raise ValueError(f"bad part index {index!r}")
start = 0
for _ in range(index):
try:
start = self.index(self.SEP, start) + 1
except ValueError:
raise IndexError("part index {index!r} out of range") from None
try:
stop = self.index(self.SEP, start + 1)
except ValueError:
stop = len(self)
return slice(start, stop)
def count_parts(self) -> int:
return self.count(self.SEP) + 1
@overload
def slice_parts(self, start: None=None, stop: None=None) -> Sequence[str]: ...
@overload
def slice_parts(self, start: slice, stop: None=None) -> Sequence[str]: ...
@overload
def slice_parts(self, start: int, stop: int) -> Sequence[str]: ...
@overload
def slice_parts(self, start: int, stop: None=None) -> str: ...
def slice_parts(self,
start: Optional[Union[int, slice]]=None,
stop: Optional[int]=None,
) -> Sequence[str]:
"""Slice the account parts like they were a list
Given a single index, return that part of the account name as a string.
Otherwise, return a list of part names sliced according to the arguments.
"""
if start is None:
part_slice = slice(None)
elif isinstance(start, slice):
part_slice = start
elif stop is None:
return self[self._find_part_slice(start)]
else:
part_slice = slice(start, stop)
return self.split(self.SEP)[part_slice]
def root_part(self, count: int=1) -> str:
"""Return the first part(s) of the account name as a string"""
try:
stop = self._find_part_slice(count - 1).stop
except IndexError:
return self
else:
return self[:stop]
class Amount(bc_amount.Amount):
"""Beancount amount after processing

View file

@ -385,7 +385,7 @@ class AgingODS(core.BaseODS[AccrualPostings, Optional[data.Account]]):
self.age_thresholds = list(AccrualAccount.by_account(key).value.aging_thresholds)
self.age_balances = [core.MutableBalance() for _ in self.age_thresholds]
accrual_date = self.date - datetime.timedelta(days=self.age_thresholds[-1])
acct_parts = key.split(':')
acct_parts = key.slice_parts()
self.use_sheet(acct_parts[1])
self.add_row()
self.add_row(self.string_cell(

View file

@ -115,3 +115,93 @@ def test_is_credit_card(acct_name, expected):
])
def test_is_opening_equity(acct_name, expected):
assert data.Account(acct_name).is_opening_equity() == expected
@pytest.mark.parametrize('acct_name', [
'Assets:Cash',
'Assets:Receivable:Accounts',
'Expenses:Other',
'Equity:Funds:Restricted',
'Income:Other',
'Liabilities:CreditCard',
'Liabilities:Payable:Accounts',
])
def test_slice_parts_no_args(acct_name):
account = data.Account(acct_name)
assert account.slice_parts() == acct_name.split(':')
@pytest.mark.parametrize('acct_name', [
'Assets:Cash',
'Assets:Receivable:Accounts',
'Expenses:Other',
'Equity:Funds:Restricted',
'Income:Other',
'Liabilities:CreditCard',
'Liabilities:Payable:Accounts',
])
def test_slice_parts_index(acct_name):
account = data.Account(acct_name)
parts = acct_name.split(':')
for index, expected in enumerate(parts):
assert account.slice_parts(index) == expected
with pytest.raises(IndexError):
account.slice_parts(index + 1)
@pytest.mark.parametrize('acct_name', [
'Assets:Cash',
'Assets:Receivable:Accounts',
'Expenses:Other',
'Equity:Funds:Restricted',
'Income:Other',
'Liabilities:CreditCard',
'Liabilities:Payable:Accounts',
])
def test_slice_parts_range(acct_name):
account = data.Account(acct_name)
parts = acct_name.split(':')
for start, stop in zip([0, 0, 1, 1], [2, 3, 2, 3]):
assert account.slice_parts(start, stop) == parts[start:stop]
@pytest.mark.parametrize('acct_name', [
'Assets:Cash',
'Assets:Receivable:Accounts',
'Expenses:Other',
'Equity:Funds:Restricted',
'Income:Other',
'Liabilities:CreditCard',
'Liabilities:Payable:Accounts',
])
def test_slice_parts_slice(acct_name):
account = data.Account(acct_name)
parts = acct_name.split(':')
for start, stop in zip([0, 0, 1, 1], [2, 3, 2, 3]):
sl = slice(start, stop)
assert account.slice_parts(sl) == parts[start:stop]
@pytest.mark.parametrize('acct_name', [
'Assets:Cash',
'Assets:Receivable:Accounts',
'Expenses:Other',
'Equity:Funds:Restricted',
'Income:Other',
'Liabilities:CreditCard',
'Liabilities:Payable:Accounts',
])
def test_count_parts(acct_name):
account = data.Account(acct_name)
assert account.count_parts() == acct_name.count(':') + 1
@pytest.mark.parametrize('acct_name', [
'Assets:Cash',
'Assets:Receivable:Accounts',
'Expenses:Other',
'Equity:Funds:Restricted',
'Income:Other',
'Liabilities:CreditCard',
'Liabilities:Payable:Accounts',
])
def test_root_part(acct_name):
account = data.Account(acct_name)
parts = acct_name.split(':')
assert account.root_part() == parts[0]
assert account.root_part(1) == parts[0]
assert account.root_part(2) == ':'.join(parts[:2])