diff --git a/conservancy_beancount/data.py b/conservancy_beancount/data.py index ddb6b8c..7437165 100644 --- a/conservancy_beancount/data.py +++ b/conservancy_beancount/data.py @@ -408,6 +408,12 @@ class Posting(BasePosting): except AttributeError: pass + def at_cost(self) -> Amount: + if self.cost is None: + return self.units + else: + return Amount(self.units.number * self.cost.number, self.cost.currency) + _KT = TypeVar('_KT', bound=Hashable) _VT = TypeVar('_VT') diff --git a/conservancy_beancount/reports/core.py b/conservancy_beancount/reports/core.py index 77f5e64..a00e38d 100644 --- a/conservancy_beancount/reports/core.py +++ b/conservancy_beancount/reports/core.py @@ -356,22 +356,10 @@ class RelatedPostings(Sequence[data.Posting]): yield post, balance def balance(self) -> Balance: - for _, balance in self.iter_with_balance(): - pass - try: - return balance - except NameError: - return Balance() + return Balance(post.units for post in self) def balance_at_cost(self) -> Balance: - balance = MutableBalance() - for post in self: - if post.cost is None: - balance += post.units - else: - number = post.units.number * post.cost.number - balance += data.Amount(number, post.cost.currency) - return balance + return Balance(post.at_cost() for post in self) def meta_values(self, key: MetaKey, diff --git a/tests/test_data_posting.py b/tests/test_data_posting.py index 9bd3347..a53b181 100644 --- a/tests/test_data_posting.py +++ b/tests/test_data_posting.py @@ -82,3 +82,27 @@ def test_from_entries_mix_txns_and_other_directives(simple_txn): assert all(source[x] == post[x] for x in range(len(source) - 1)) assert isinstance(post.account, data.Account) assert post.meta['note'] # Only works with PostingMeta + +@pytest.mark.parametrize('cost_num', [105, 110, 115]) +def test_at_cost(cost_num): + post = data.Posting( + 'Income:Donations', + testutil.Amount(25, 'EUR'), + testutil.Cost(cost_num, 'JPY'), + None, + '*', + None, + ) + assert post.at_cost() == testutil.Amount(25 * cost_num, 'JPY') + +def test_at_cost_no_cost(): + amount = testutil.Amount(25, 'EUR') + post = data.Posting( + 'Income:Donations', + amount, + None, + None, + '*', + None, + ) + assert post.at_cost() == amount