ranges: Start module.

The ledger report wants to use this functionality, so make it available in a
higher-level module.

I took the opportunity to clean up a lot of the surrounding type
declarations. It is less flexible, since it relies on the static list of
types in RangeT, but I don't think the other method actually worked at all
except by cheating with generic Any.
This commit is contained in:
Brett Smith 2020-06-15 09:14:42 -04:00
parent 760e0a8cd9
commit a2ee9c73fe
3 changed files with 127 additions and 35 deletions

View file

@ -21,6 +21,7 @@ import re
from .. import config as configmod from .. import config as configmod
from .. import data from .. import data
from .. import errors as errormod from .. import errors as errormod
from .. import ranges
from typing import ( from typing import (
Any, Any,
@ -49,10 +50,10 @@ from ..beancount_types import (
# I expect these will become configurable in the future, which is why I'm # I expect these will become configurable in the future, which is why I'm
# keeping them outside of a class, but for now constants will do. # keeping them outside of a class, but for now constants will do.
DEFAULT_START_DATE: datetime.date = datetime.date(2020, 3, 1) DEFAULT_START_DATE = datetime.date(2020, 3, 1)
# The default stop date leaves a little room after so it's easy to test # The default stop date leaves a little room after so it's easy to test
# dates past the far end of the range. # dates past the far end of the range.
DEFAULT_STOP_DATE: datetime.date = datetime.date(datetime.MAXYEAR, 1, 1) DEFAULT_STOP_DATE = datetime.date(datetime.MAXYEAR, 1, 1)
### TYPE DEFINITIONS ### TYPE DEFINITIONS
@ -74,38 +75,6 @@ class Hook(Generic[Entry], metaclass=abc.ABCMeta):
### HELPER CLASSES ### HELPER CLASSES
class LessComparable(metaclass=abc.ABCMeta):
@abc.abstractmethod
def __le__(self, other: Any) -> bool: ...
@abc.abstractmethod
def __lt__(self, other: Any) -> bool: ...
CT = TypeVar('CT', bound=LessComparable)
class _GenericRange(Generic[CT]):
"""Convenience class to check whether a value is within a range.
`foo in generic_range` is equivalent to `start <= foo < stop`.
Since we have multiple user-configurable ranges, having the check
encapsulated in an object helps implement the check consistently, and
makes it easier for subclasses to override.
"""
def __init__(self, start: CT, stop: CT) -> None:
self.start = start
self.stop = stop
def __repr__(self) -> str:
return "{clsname}({self.start!r}, {self.stop!r})".format(
clsname=type(self).__name__,
self=self,
)
def __contains__(self, item: CT) -> bool:
return self.start <= item < self.stop
class MetadataEnum: class MetadataEnum:
"""Map acceptable metadata values to their normalized forms. """Map acceptable metadata values to their normalized forms.
@ -178,7 +147,7 @@ class MetadataEnum:
class TransactionHook(Hook[Transaction]): class TransactionHook(Hook[Transaction]):
DIRECTIVE = Transaction DIRECTIVE = Transaction
SKIP_FLAGS: Container[str] = frozenset() SKIP_FLAGS: Container[str] = frozenset()
TXN_DATE_RANGE: _GenericRange = _GenericRange(DEFAULT_START_DATE, DEFAULT_STOP_DATE) TXN_DATE_RANGE = ranges.DateRange(DEFAULT_START_DATE, DEFAULT_STOP_DATE)
def _run_on_txn(self, txn: Transaction) -> bool: def _run_on_txn(self, txn: Transaction) -> bool:
"""Check whether we should run on a given transaction """Check whether we should run on a given transaction

View file

@ -0,0 +1,59 @@
"""ranges.py - Higher-typed range classes"""
# Copyright © 2020 Brett Smith
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
import datetime
from decimal import Decimal
from typing import (
Generic,
TypeVar,
Union,
)
RangeT = TypeVar(
'RangeT',
# This is a relatively arbitrary set of types. Feel free to add to it if
# you need; the types just need to support enough comparisons to implement
# _GenericRange.__contains__.
datetime.date,
datetime.datetime,
datetime.time,
Union[int, Decimal],
)
class _GenericRange(Generic[RangeT]):
"""range for higher-level types
This class knows how to check membership for higher-level types just like
Python's built-in range. It does not know how to iterate or step.
"""
def __init__(self, start: RangeT, stop: RangeT) -> None:
self.start: RangeT = start
self.stop: RangeT = stop
def __repr__(self) -> str:
return "{clsname}({self.start!r}, {self.stop!r})".format(
clsname=type(self).__name__,
self=self,
)
def __contains__(self, item: RangeT) -> bool:
return self.start <= item < self.stop
DateRange = _GenericRange[datetime.date]
DecimalCompatRange = _GenericRange[Union[int, Decimal]]

64
tests/test_ranges.py Normal file
View file

@ -0,0 +1,64 @@
"""test_ranges.py - Unit tests for range classes"""
# Copyright © 2020 Brett Smith
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
import datetime
import pytest
from conservancy_beancount import ranges
ONE_DAY = datetime.timedelta(days=1)
@pytest.mark.parametrize('start,stop', [
# One month
(datetime.date(2018, 3, 1), datetime.date(2018, 4, 1)),
# Three months
(datetime.date(2018, 6, 1), datetime.date(2018, 9, 1)),
# Six months, spanning year
(datetime.date(2018, 9, 1), datetime.date(2019, 3, 1)),
# Nine months
(datetime.date(2018, 2, 1), datetime.date(2018, 12, 1)),
# Twelve months on Jan 1
(datetime.date(2018, 1, 1), datetime.date(2019, 1, 1)),
# Twelve months spanning year
(datetime.date(2018, 3, 1), datetime.date(2019, 3, 1)),
# Eighteen months spanning year
(datetime.date(2018, 3, 1), datetime.date(2019, 9, 1)),
# Wild
(datetime.date(2018, 1, 1), datetime.date(2020, 4, 15)),
])
def test_date_range(start, stop):
date_range = ranges.DateRange(start, stop)
assert (start - ONE_DAY) not in date_range
assert start in date_range
assert (start + ONE_DAY) in date_range
assert (stop - ONE_DAY) in date_range
assert stop not in date_range
assert (stop + ONE_DAY) not in date_range
def test_date_range_one_day():
start = datetime.date(2018, 7, 1)
date_range = ranges.DateRange(start, start + ONE_DAY)
assert (start - ONE_DAY) not in date_range
assert start in date_range
assert (start + ONE_DAY) not in date_range
def test_date_range_empty():
date = datetime.date(2018, 8, 10)
date_range = ranges.DateRange(date, date)
assert (date - ONE_DAY) not in date_range
assert date not in date_range
assert (date + ONE_DAY) not in date_range