a2ee9c73fe
The ledger report wants to use this functionality, so make it available in a higher-level module. I took the opportunity to clean up a lot of the surrounding type declarations. It is less flexible, since it relies on the static list of types in RangeT, but I don't think the other method actually worked at all except by cheating with generic Any.
59 lines
1.8 KiB
Python
59 lines
1.8 KiB
Python
"""ranges.py - Higher-typed range classes"""
|
|
# Copyright © 2020 Brett Smith
|
|
#
|
|
# This program is free software: you can redistribute it and/or modify
|
|
# it under the terms of the GNU Affero General Public License as published by
|
|
# the Free Software Foundation, either version 3 of the License, or
|
|
# (at your option) any later version.
|
|
#
|
|
# This program is distributed in the hope that it will be useful,
|
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
# GNU Affero General Public License for more details.
|
|
#
|
|
# You should have received a copy of the GNU Affero General Public License
|
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|
|
|
import datetime
|
|
|
|
from decimal import Decimal
|
|
|
|
from typing import (
|
|
Generic,
|
|
TypeVar,
|
|
Union,
|
|
)
|
|
|
|
RangeT = TypeVar(
    'RangeT',
    # This is a relatively arbitrary set of types. Feel free to add to it if
    # you need; the types just need to support enough comparisons (<= and <)
    # to implement _GenericRange.__contains__.
    datetime.date,
    datetime.datetime,
    datetime.time,
    Union[int, Decimal],
)


class _GenericRange(Generic[RangeT]):
    """range for higher-level types

    This class knows how to check membership for higher-level types just like
    Python's built-in range: the start bound is inclusive and the stop bound
    is exclusive. It does not know how to iterate or step.

    Prefer the concrete subclasses defined alongside it (DateRange,
    DecimalCompatRange). Because they are real subclasses rather than
    subscripted aliases, repr() shows their own name and isinstance() checks
    against them work.
    """

    def __init__(self, start: RangeT, stop: RangeT) -> None:
        self.start: RangeT = start
        self.stop: RangeT = stop

    def __repr__(self) -> str:
        # type(self) rather than a hard-coded name so subclasses report
        # themselves, e.g. DateRange(datetime.date(...), datetime.date(...)).
        return "{clsname}({self.start!r}, {self.stop!r})".format(
            clsname=type(self).__name__,
            self=self,
        )

    def __contains__(self, item: RangeT) -> bool:
        """Return whether start <= item < stop (half-open, like range)."""
        return self.start <= item < self.stop


# These are real subclasses instead of `_GenericRange[...]` aliases:
# calling an alias instantiates the private base class directly (so repr()
# says `_GenericRange`), and isinstance() against a subscripted generic
# raises TypeError.
class DateRange(_GenericRange[datetime.date]):
    """Half-open range of dates: start <= item < stop."""


class DecimalCompatRange(_GenericRange[Union[int, Decimal]]):
    """Half-open range of ints and/or Decimals: start <= item < stop."""
|