reports: Balance.__eq__ respects tolerance.
This commit is contained in:
parent
110e5038e1
commit
cd1766adcf
2 changed files with 56 additions and 4 deletions
|
@ -132,10 +132,12 @@ class Balance(Mapping[str, data.Amount]):
|
||||||
return type(self)(retval_map.values())
|
return type(self)(retval_map.values())
|
||||||
|
|
||||||
def __eq__(self, other: Any) -> bool:
|
def __eq__(self, other: Any) -> bool:
|
||||||
if (self.is_zero()
|
if isinstance(other, Balance):
|
||||||
and isinstance(other, Balance)
|
clean_self = self.clean_copy()
|
||||||
and other.is_zero()):
|
clean_other = other.clean_copy()
|
||||||
return True
|
return len(clean_self) == len(clean_other) and all(
|
||||||
|
clean_self[key] == clean_other.get(key) for key in clean_self
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
return super().__eq__(other)
|
return super().__eq__(other)
|
||||||
|
|
||||||
|
@ -160,6 +162,17 @@ class Balance(Mapping[str, data.Amount]):
|
||||||
) -> bool:
|
) -> bool:
|
||||||
return all(op_func(amt.number, operand) for amt in self.values())
|
return all(op_func(amt.number, operand) for amt in self.values())
|
||||||
|
|
||||||
|
def copy(self: BalanceType) -> BalanceType:
|
||||||
|
return type(self)(self.values())
|
||||||
|
|
||||||
|
def clean_copy(self: BalanceType, tolerance: Optional[Decimal]=None) -> BalanceType:
|
||||||
|
if tolerance is None:
|
||||||
|
tolerance = self.tolerance
|
||||||
|
return type(self)(
|
||||||
|
amount for amount in self.values()
|
||||||
|
if abs(amount.number) >= tolerance
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def within_tolerance(dec: DecimalCompat, tolerance: DecimalCompat) -> bool:
|
def within_tolerance(dec: DecimalCompat, tolerance: DecimalCompat) -> bool:
|
||||||
dec = cast(Decimal, dec)
|
dec = cast(Decimal, dec)
|
||||||
|
|
|
@ -34,6 +34,8 @@ DEFAULT_STRINGS = [
|
||||||
({'JPY': '-5500.00', 'BRL': '-8500.00'}, "-8,500.00 BRL, -5,500 JPY"),
|
({'JPY': '-5500.00', 'BRL': '-8500.00'}, "-8,500.00 BRL, -5,500 JPY"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
TOLERANCES = [Decimal(n) for n in ['.1', '.01', '.001', 0]]
|
||||||
|
|
||||||
def amounts_from_map(currency_map):
|
def amounts_from_map(currency_map):
|
||||||
for code, number in currency_map.items():
|
for code, number in currency_map.items():
|
||||||
yield testutil.Amount(number, code)
|
yield testutil.Amount(number, code)
|
||||||
|
@ -219,6 +221,15 @@ def test_eq(map1, map2, expected):
|
||||||
actual = bal1 == bal2
|
actual = bal1 == bal2
|
||||||
assert actual == expected
|
assert actual == expected
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('tolerance', TOLERANCES)
|
||||||
|
def test_eq_considers_tolerance(tolerance):
|
||||||
|
tolerance = Decimal(tolerance)
|
||||||
|
mapping = {'EUR': 100, 'USD': '.002'}
|
||||||
|
bal1 = core.Balance(amounts_from_map(mapping))
|
||||||
|
mapping['USD'] = '.004'
|
||||||
|
bal2 = core.Balance(amounts_from_map(mapping), tolerance)
|
||||||
|
assert (bal1 == bal2) == (tolerance > Decimal('.002'))
|
||||||
|
|
||||||
@pytest.mark.parametrize('number,currency', {
|
@pytest.mark.parametrize('number,currency', {
|
||||||
(50, 'USD'),
|
(50, 'USD'),
|
||||||
(-50, 'USD'),
|
(-50, 'USD'),
|
||||||
|
@ -294,6 +305,34 @@ def test_iadd_balance(mapping):
|
||||||
expected = core.Balance(amounts_from_map(expect_numbers))
|
expected = core.Balance(amounts_from_map(expect_numbers))
|
||||||
assert balance == expected
|
assert balance == expected
|
||||||
|
|
||||||
|
def test_copy():
|
||||||
|
amounts = frozenset(amounts_from_map({'USD': 10, 'EUR': '.001'}))
|
||||||
|
# Use a ridiculous tolerance to test it doesn't matter.
|
||||||
|
actual = core.Balance(amounts, 100).copy()
|
||||||
|
assert frozenset(actual.values()) == amounts
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('tolerance', TOLERANCES)
|
||||||
|
def test_clean_copy(tolerance):
|
||||||
|
usd = testutil.Amount(10)
|
||||||
|
eur = testutil.Amount('.002', 'EUR')
|
||||||
|
actual = core.Balance([usd, eur], tolerance).clean_copy()
|
||||||
|
if tolerance < eur.number:
|
||||||
|
expected = {usd, eur}
|
||||||
|
else:
|
||||||
|
expected = {usd}
|
||||||
|
assert frozenset(actual.values()) == expected
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('tolerance', TOLERANCES)
|
||||||
|
def test_clean_copy_arg(tolerance):
|
||||||
|
usd = testutil.Amount(10)
|
||||||
|
eur = testutil.Amount('.002', 'EUR')
|
||||||
|
actual = core.Balance([usd, eur], 0).clean_copy(tolerance)
|
||||||
|
if tolerance < eur.number:
|
||||||
|
expected = {usd, eur}
|
||||||
|
else:
|
||||||
|
expected = {usd}
|
||||||
|
assert frozenset(actual.values()) == expected
|
||||||
|
|
||||||
@pytest.mark.parametrize('mapping,expected', DEFAULT_STRINGS)
|
@pytest.mark.parametrize('mapping,expected', DEFAULT_STRINGS)
|
||||||
def test_str(mapping, expected):
|
def test_str(mapping, expected):
|
||||||
balance = core.Balance(amounts_from_map(mapping))
|
balance = core.Balance(amounts_from_map(mapping))
|
||||||
|
|
Loading…
Reference in a new issue