diff --git a/conservancy_beancount/reports/query.py b/conservancy_beancount/reports/query.py index 9d1dab9..ed6ea4b 100644 --- a/conservancy_beancount/reports/query.py +++ b/conservancy_beancount/reports/query.py @@ -169,12 +169,19 @@ class DBColumn(bc_query_compile.EvalColumn): _db_cursor: ClassVar[sqlite3.Cursor] _db_query: ClassVar[str] _dtype: ClassVar[Type] = set + _return: ClassVar[Callable[['DBColumn'], object]] __intypes__ = [Posting] @classmethod def with_db(cls, connection: sqlite3.Connection) -> Type['DBColumn']: return type(cls.__name__, (cls,), {'_db_cursor': connection.cursor()}) + def __init_subclass__(cls) -> None: + if issubclass(cls._dtype, set): + cls._return = cls._return_set + else: + cls._return = cls._return_scalar + def __init__(self, colname: Optional[str]=None) -> None: if not hasattr(self, '_db_cursor'): if colname is None: @@ -186,14 +193,17 @@ class DBColumn(bc_query_compile.EvalColumn): entity = meta.get('entity') return entity if isinstance(entity, str) else '\0' + def _return_scalar(self) -> object: + row = self._db_cursor.fetchone() + return self._dtype() if row is None else self._dtype(row[0]) + + def _return_set(self) -> object: + return self._dtype(value for value, in self._db_cursor) + def __call__(self, context: PostingContext) -> object: entity = self._entity(ContextMeta(context)) self._db_cursor.execute(self._db_query, (entity,)) - if issubclass(self._dtype, set): - return self._dtype(value for value, in self._db_cursor) - else: - row = self._db_cursor.fetchone() - return self._dtype() if row is None else self._dtype(row[0]) + return self._return() class DBEmail(DBColumn):