query: Refactor DBColumn.

Avoid an issubclass check on every call, and make it easier for subclasses
to override part of the call implementation.
This commit is contained in:
Brett Smith 2021-03-15 13:35:56 -04:00
parent 6703d1af87
commit b880115774

View file

@ -169,12 +169,19 @@ class DBColumn(bc_query_compile.EvalColumn):
_db_cursor: ClassVar[sqlite3.Cursor] _db_cursor: ClassVar[sqlite3.Cursor]
_db_query: ClassVar[str] _db_query: ClassVar[str]
_dtype: ClassVar[Type] = set _dtype: ClassVar[Type] = set
_return: ClassVar[Callable[['DBColumn'], object]]
__intypes__ = [Posting] __intypes__ = [Posting]
@classmethod @classmethod
def with_db(cls, connection: sqlite3.Connection) -> Type['DBColumn']: def with_db(cls, connection: sqlite3.Connection) -> Type['DBColumn']:
return type(cls.__name__, (cls,), {'_db_cursor': connection.cursor()}) return type(cls.__name__, (cls,), {'_db_cursor': connection.cursor()})
def __init_subclass__(cls) -> None:
if issubclass(cls._dtype, set):
cls._return = cls._return_set
else:
cls._return = cls._return_scalar
def __init__(self, colname: Optional[str]=None) -> None: def __init__(self, colname: Optional[str]=None) -> None:
if not hasattr(self, '_db_cursor'): if not hasattr(self, '_db_cursor'):
if colname is None: if colname is None:
@ -186,14 +193,17 @@ class DBColumn(bc_query_compile.EvalColumn):
entity = meta.get('entity') entity = meta.get('entity')
return entity if isinstance(entity, str) else '\0' return entity if isinstance(entity, str) else '\0'
def _return_scalar(self) -> object:
row = self._db_cursor.fetchone()
return self._dtype() if row is None else self._dtype(row[0])
def _return_set(self) -> object:
return self._dtype(value for value, in self._db_cursor)
def __call__(self, context: PostingContext) -> object: def __call__(self, context: PostingContext) -> object:
entity = self._entity(ContextMeta(context)) entity = self._entity(ContextMeta(context))
self._db_cursor.execute(self._db_query, (entity,)) self._db_cursor.execute(self._db_query, (entity,))
if issubclass(self._dtype, set): return self._return()
return self._dtype(value for value, in self._db_cursor)
else:
row = self._db_cursor.fetchone()
return self._dtype() if row is None else self._dtype(row[0])
class DBEmail(DBColumn): class DBEmail(DBColumn):