diff --git a/conservancy_beancount/data.py b/conservancy_beancount/data.py index b076977..b985ef2 100644 --- a/conservancy_beancount/data.py +++ b/conservancy_beancount/data.py @@ -31,6 +31,7 @@ from beancount.core import position as bc_position from typing import ( cast, + overload, Callable, Hashable, Iterable, @@ -121,6 +122,64 @@ class Account(str): return prefix return None + def _find_part_slice(self, index: int) -> slice: + if index < 0: + raise ValueError(f"bad part index {index!r}") + start = 0 + for _ in range(index): + try: + start = self.index(self.SEP, start) + 1 + except ValueError: + raise IndexError("part index {index!r} out of range") from None + try: + stop = self.index(self.SEP, start + 1) + except ValueError: + stop = len(self) + return slice(start, stop) + + def count_parts(self) -> int: + return self.count(self.SEP) + 1 + + @overload + def slice_parts(self, start: None=None, stop: None=None) -> Sequence[str]: ... + + @overload + def slice_parts(self, start: slice, stop: None=None) -> Sequence[str]: ... + + @overload + def slice_parts(self, start: int, stop: int) -> Sequence[str]: ... + + @overload + def slice_parts(self, start: int, stop: None=None) -> str: ... + + def slice_parts(self, + start: Optional[Union[int, slice]]=None, + stop: Optional[int]=None, + ) -> Sequence[str]: + """Slice the account parts like they were a list + + Given a single index, return that part of the account name as a string. + Otherwise, return a list of part names sliced according to the arguments. + """ + if start is None: + part_slice = slice(None) + elif isinstance(start, slice): + part_slice = start + elif stop is None: + return self[self._find_part_slice(start)] + else: + part_slice = slice(start, stop) + return self.split(self.SEP)[part_slice] + + def root_part(self, count: int=1) -> str: + """Return the first part(s) of the account name as a string""" + try: + stop = self._find_part_slice(count - 1).stop + except IndexError: + return self + else: + return self[:stop] + class Amount(bc_amount.Amount): """Beancount amount after processing diff --git a/conservancy_beancount/reports/accrual.py b/conservancy_beancount/reports/accrual.py index f8c2e9a..9f29488 100644 --- a/conservancy_beancount/reports/accrual.py +++ b/conservancy_beancount/reports/accrual.py @@ -385,7 +385,7 @@ class AgingODS(core.BaseODS[AccrualPostings, Optional[data.Account]]): self.age_thresholds = list(AccrualAccount.by_account(key).value.aging_thresholds) self.age_balances = [core.MutableBalance() for _ in self.age_thresholds] accrual_date = self.date - datetime.timedelta(days=self.age_thresholds[-1]) - acct_parts = key.split(':') + acct_parts = key.slice_parts() self.use_sheet(acct_parts[1]) self.add_row() self.add_row(self.string_cell( diff --git a/tests/test_data_account.py b/tests/test_data_account.py index c7a21b2..d67bd67 100644 --- a/tests/test_data_account.py +++ b/tests/test_data_account.py @@ -115,3 +115,93 @@ def test_is_credit_card(acct_name, expected): ]) def test_is_opening_equity(acct_name, expected): assert data.Account(acct_name).is_opening_equity() == expected + +@pytest.mark.parametrize('acct_name', [ + 'Assets:Cash', + 'Assets:Receivable:Accounts', + 'Expenses:Other', + 'Equity:Funds:Restricted', + 'Income:Other', + 'Liabilities:CreditCard', + 'Liabilities:Payable:Accounts', +]) +def test_slice_parts_no_args(acct_name): + account = data.Account(acct_name) + assert account.slice_parts() == acct_name.split(':') + +@pytest.mark.parametrize('acct_name', [ + 'Assets:Cash', + 'Assets:Receivable:Accounts', + 'Expenses:Other', + 'Equity:Funds:Restricted', + 'Income:Other', + 'Liabilities:CreditCard', + 'Liabilities:Payable:Accounts', +]) +def test_slice_parts_index(acct_name): + account = data.Account(acct_name) + parts = acct_name.split(':') + for index, expected in enumerate(parts): + assert account.slice_parts(index) == expected + with pytest.raises(IndexError): + account.slice_parts(index + 1) + +@pytest.mark.parametrize('acct_name', [ + 'Assets:Cash', + 'Assets:Receivable:Accounts', + 'Expenses:Other', + 'Equity:Funds:Restricted', + 'Income:Other', + 'Liabilities:CreditCard', + 'Liabilities:Payable:Accounts', +]) +def test_slice_parts_range(acct_name): + account = data.Account(acct_name) + parts = acct_name.split(':') + for start, stop in zip([0, 0, 1, 1], [2, 3, 2, 3]): + assert account.slice_parts(start, stop) == parts[start:stop] + +@pytest.mark.parametrize('acct_name', [ + 'Assets:Cash', + 'Assets:Receivable:Accounts', + 'Expenses:Other', + 'Equity:Funds:Restricted', + 'Income:Other', + 'Liabilities:CreditCard', + 'Liabilities:Payable:Accounts', +]) +def test_slice_parts_slice(acct_name): + account = data.Account(acct_name) + parts = acct_name.split(':') + for start, stop in zip([0, 0, 1, 1], [2, 3, 2, 3]): + sl = slice(start, stop) + assert account.slice_parts(sl) == parts[start:stop] + +@pytest.mark.parametrize('acct_name', [ + 'Assets:Cash', + 'Assets:Receivable:Accounts', + 'Expenses:Other', + 'Equity:Funds:Restricted', + 'Income:Other', + 'Liabilities:CreditCard', + 'Liabilities:Payable:Accounts', +]) +def test_count_parts(acct_name): + account = data.Account(acct_name) + assert account.count_parts() == acct_name.count(':') + 1 + +@pytest.mark.parametrize('acct_name', [ + 'Assets:Cash', + 'Assets:Receivable:Accounts', + 'Expenses:Other', + 'Equity:Funds:Restricted', + 'Income:Other', + 'Liabilities:CreditCard', + 'Liabilities:Payable:Accounts', +]) +def test_root_part(acct_name): + account = data.Account(acct_name) + parts = acct_name.split(':') + assert account.root_part() == parts[0] + assert account.root_part(1) == parts[0] + assert account.root_part(2) == ':'.join(parts[:2])