data: Add part slicing methods to Account.
This commit is contained in:
parent
2b5cb0eca6
commit
8d3d7e7ce4
3 changed files with 150 additions and 1 deletions
|
@ -31,6 +31,7 @@ from beancount.core import position as bc_position
|
|||
|
||||
from typing import (
|
||||
cast,
|
||||
overload,
|
||||
Callable,
|
||||
Hashable,
|
||||
Iterable,
|
||||
|
@ -121,6 +122,64 @@ class Account(str):
|
|||
return prefix
|
||||
return None
|
||||
|
||||
def _find_part_slice(self, index: int) -> slice:
|
||||
if index < 0:
|
||||
raise ValueError(f"bad part index {index!r}")
|
||||
start = 0
|
||||
for _ in range(index):
|
||||
try:
|
||||
start = self.index(self.SEP, start) + 1
|
||||
except ValueError:
|
||||
raise IndexError("part index {index!r} out of range") from None
|
||||
try:
|
||||
stop = self.index(self.SEP, start + 1)
|
||||
except ValueError:
|
||||
stop = len(self)
|
||||
return slice(start, stop)
|
||||
|
||||
def count_parts(self) -> int:
|
||||
return self.count(self.SEP) + 1
|
||||
|
||||
@overload
|
||||
def slice_parts(self, start: None=None, stop: None=None) -> Sequence[str]: ...
|
||||
|
||||
@overload
|
||||
def slice_parts(self, start: slice, stop: None=None) -> Sequence[str]: ...
|
||||
|
||||
@overload
|
||||
def slice_parts(self, start: int, stop: int) -> Sequence[str]: ...
|
||||
|
||||
@overload
|
||||
def slice_parts(self, start: int, stop: None=None) -> str: ...
|
||||
|
||||
def slice_parts(self,
|
||||
start: Optional[Union[int, slice]]=None,
|
||||
stop: Optional[int]=None,
|
||||
) -> Sequence[str]:
|
||||
"""Slice the account parts like they were a list
|
||||
|
||||
Given a single index, return that part of the account name as a string.
|
||||
Otherwise, return a list of part names sliced according to the arguments.
|
||||
"""
|
||||
if start is None:
|
||||
part_slice = slice(None)
|
||||
elif isinstance(start, slice):
|
||||
part_slice = start
|
||||
elif stop is None:
|
||||
return self[self._find_part_slice(start)]
|
||||
else:
|
||||
part_slice = slice(start, stop)
|
||||
return self.split(self.SEP)[part_slice]
|
||||
|
||||
def root_part(self, count: int=1) -> str:
|
||||
"""Return the first part(s) of the account name as a string"""
|
||||
try:
|
||||
stop = self._find_part_slice(count - 1).stop
|
||||
except IndexError:
|
||||
return self
|
||||
else:
|
||||
return self[:stop]
|
||||
|
||||
|
||||
class Amount(bc_amount.Amount):
|
||||
"""Beancount amount after processing
|
||||
|
|
|
@ -385,7 +385,7 @@ class AgingODS(core.BaseODS[AccrualPostings, Optional[data.Account]]):
|
|||
self.age_thresholds = list(AccrualAccount.by_account(key).value.aging_thresholds)
|
||||
self.age_balances = [core.MutableBalance() for _ in self.age_thresholds]
|
||||
accrual_date = self.date - datetime.timedelta(days=self.age_thresholds[-1])
|
||||
acct_parts = key.split(':')
|
||||
acct_parts = key.slice_parts()
|
||||
self.use_sheet(acct_parts[1])
|
||||
self.add_row()
|
||||
self.add_row(self.string_cell(
|
||||
|
|
|
@ -115,3 +115,93 @@ def test_is_credit_card(acct_name, expected):
|
|||
])
|
||||
def test_is_opening_equity(acct_name, expected):
|
||||
assert data.Account(acct_name).is_opening_equity() == expected
|
||||
|
||||
@pytest.mark.parametrize('acct_name', [
|
||||
'Assets:Cash',
|
||||
'Assets:Receivable:Accounts',
|
||||
'Expenses:Other',
|
||||
'Equity:Funds:Restricted',
|
||||
'Income:Other',
|
||||
'Liabilities:CreditCard',
|
||||
'Liabilities:Payable:Accounts',
|
||||
])
|
||||
def test_slice_parts_no_args(acct_name):
|
||||
account = data.Account(acct_name)
|
||||
assert account.slice_parts() == acct_name.split(':')
|
||||
|
||||
@pytest.mark.parametrize('acct_name', [
|
||||
'Assets:Cash',
|
||||
'Assets:Receivable:Accounts',
|
||||
'Expenses:Other',
|
||||
'Equity:Funds:Restricted',
|
||||
'Income:Other',
|
||||
'Liabilities:CreditCard',
|
||||
'Liabilities:Payable:Accounts',
|
||||
])
|
||||
def test_slice_parts_index(acct_name):
|
||||
account = data.Account(acct_name)
|
||||
parts = acct_name.split(':')
|
||||
for index, expected in enumerate(parts):
|
||||
assert account.slice_parts(index) == expected
|
||||
with pytest.raises(IndexError):
|
||||
account.slice_parts(index + 1)
|
||||
|
||||
@pytest.mark.parametrize('acct_name', [
|
||||
'Assets:Cash',
|
||||
'Assets:Receivable:Accounts',
|
||||
'Expenses:Other',
|
||||
'Equity:Funds:Restricted',
|
||||
'Income:Other',
|
||||
'Liabilities:CreditCard',
|
||||
'Liabilities:Payable:Accounts',
|
||||
])
|
||||
def test_slice_parts_range(acct_name):
|
||||
account = data.Account(acct_name)
|
||||
parts = acct_name.split(':')
|
||||
for start, stop in zip([0, 0, 1, 1], [2, 3, 2, 3]):
|
||||
assert account.slice_parts(start, stop) == parts[start:stop]
|
||||
|
||||
@pytest.mark.parametrize('acct_name', [
|
||||
'Assets:Cash',
|
||||
'Assets:Receivable:Accounts',
|
||||
'Expenses:Other',
|
||||
'Equity:Funds:Restricted',
|
||||
'Income:Other',
|
||||
'Liabilities:CreditCard',
|
||||
'Liabilities:Payable:Accounts',
|
||||
])
|
||||
def test_slice_parts_slice(acct_name):
|
||||
account = data.Account(acct_name)
|
||||
parts = acct_name.split(':')
|
||||
for start, stop in zip([0, 0, 1, 1], [2, 3, 2, 3]):
|
||||
sl = slice(start, stop)
|
||||
assert account.slice_parts(sl) == parts[start:stop]
|
||||
|
||||
@pytest.mark.parametrize('acct_name', [
|
||||
'Assets:Cash',
|
||||
'Assets:Receivable:Accounts',
|
||||
'Expenses:Other',
|
||||
'Equity:Funds:Restricted',
|
||||
'Income:Other',
|
||||
'Liabilities:CreditCard',
|
||||
'Liabilities:Payable:Accounts',
|
||||
])
|
||||
def test_count_parts(acct_name):
|
||||
account = data.Account(acct_name)
|
||||
assert account.count_parts() == acct_name.count(':') + 1
|
||||
|
||||
@pytest.mark.parametrize('acct_name', [
|
||||
'Assets:Cash',
|
||||
'Assets:Receivable:Accounts',
|
||||
'Expenses:Other',
|
||||
'Equity:Funds:Restricted',
|
||||
'Income:Other',
|
||||
'Liabilities:CreditCard',
|
||||
'Liabilities:Payable:Accounts',
|
||||
])
|
||||
def test_root_part(acct_name):
|
||||
account = data.Account(acct_name)
|
||||
parts = acct_name.split(':')
|
||||
assert account.root_part() == parts[0]
|
||||
assert account.root_part(1) == parts[0]
|
||||
assert account.root_part(2) == ':'.join(parts[:2])
|
||||
|
|
Loading…
Reference in a new issue