query: Add database columns.

This commit is contained in:
Brett Smith 2021-03-12 16:11:17 -05:00
parent f0a5116429
commit 2e8e70cff3
2 changed files with 94 additions and 12 deletions

View file

@ -52,6 +52,8 @@ import enum
import functools
import itertools
import logging
import os
import sqlite3
import sys
from typing import (
@ -111,7 +113,7 @@ PROGNAME = 'query-report'
logger = logging.getLogger('conservancy_beancount.reports.query')
CellFunc = Callable[[Any], odf.table.TableCell]
EnvironmentFunctions = Dict[
EnvironmentColumns = Dict[
# The real key type is something like:
# Union[str, Tuple[str, Type, ...]]
# but two issues with that. One, you can't use Ellipses in a Tuple like
@ -119,8 +121,9 @@ EnvironmentFunctions = Dict[
# declare it anyway, and mypy infers it as Sequence[object]. So just use
# that.
Sequence[object],
Type[bc_query_compile.EvalFunction],
Type[bc_query_compile.EvalColumn],
]
EnvironmentFunctions = Dict[Sequence[object], Type[bc_query_compile.EvalFunction]]
RowTypes = Sequence[Tuple[str, Type]]
Rows = Sequence[NamedTuple]
RTResult = Optional[Mapping[Any, Any]]
@ -161,6 +164,65 @@ def ContextMeta(context: PostingContext) -> data.PostingMeta:
return data.PostingMeta(context.entry, sys.maxsize, context.posting).detached()
class DBColumn(bc_query_compile.EvalColumn):
_db_cursor: ClassVar[sqlite3.Cursor]
_db_query: ClassVar[str]
_dtype: ClassVar[Type] = set
__intypes__ = [Posting]
@classmethod
def with_db(cls, connection: sqlite3.Connection) -> Type['DBColumn']:
return type(cls.__name__, (cls,), {'_db_cursor': connection.cursor()})
def __init__(self, colname: Optional[str]=None) -> None:
if not hasattr(self, '_db_cursor'):
if colname is None:
colname = type(self).__name__.lower().replace('db', 'db_', 1)
raise RuntimeError(f"no entity database loaded - {colname} not available")
super().__init__(self._dtype)
def _entity(self, meta: data.PostingMeta) -> str:
entity = meta.get('entity')
return entity if isinstance(entity, str) else '\0'
def __call__(self, context: PostingContext) -> object:
entity = self._entity(ContextMeta(context))
self._db_cursor.execute(self._db_query, (entity,))
if issubclass(self._dtype, set):
return self._dtype(value for value, in self._db_cursor)
else:
row = self._db_cursor.fetchone()
return self._dtype() if row is None else self._dtype(row[0])
class DBEmail(DBColumn):
"""Look up an entity's email addresses from the database"""
_db_query = """
SELECT email.email_address
FROM donor
JOIN donor_email_address_mapping map ON donor.id = map.donor_id
JOIN email_address email ON map.email_address_id = email.id
WHERE donor.ledger_entity_id = ?
"""
class DBId(DBColumn):
"""Look up an entity's numeric id from the database"""
_db_query = "SELECT id FROM donor WHERE ledger_entity_id = ?"
_dtype = int
class DBPostal(DBColumn):
"""Look up an entity's postal addresses from the database"""
_db_query = """
SELECT postal.formatted_address
FROM donor
JOIN donor_postal_address_mapping map ON donor.id = map.donor_id
JOIN postal_address postal ON map.postal_address_id = postal.id
WHERE donor.ledger_entity_id = ?
"""
class MetaDocs(bc_query_env.AnyMeta):
"""Return a list of document links from metadata."""
def __init__(self, operands: List[bc_query_compile.EvalNode]) -> None:
@ -343,31 +405,51 @@ class AggregateSet(bc_query_compile.EvalAggregator):
class _EnvironmentMixin:
db_path = Path('Financial', 'Ledger', 'supporters.db')
columns: EnvironmentColumns
functions: EnvironmentFunctions
@classmethod
def with_rt_client(
cls,
rt_client: Optional[rt.Rt],
cache_key: Hashable,
) -> Type['_EnvironmentMixin']:
def with_config(cls, config: configmod.Config) -> Type['_EnvironmentMixin']:
columns = cls.columns.copy()
repo_path = config.repository_path()
try:
if repo_path is None:
raise sqlite3.Error("no repository configured to host database")
db_conn = sqlite3.connect(os.fspath(repo_path / cls.db_path))
except (OSError, sqlite3.Error):
columns['db_email'] = DBEmail
columns['db_id'] = DBId
columns['db_postal'] = DBPostal
else:
columns['db_email'] = DBEmail.with_db(db_conn)
columns['db_id'] = DBId.with_db(db_conn)
columns['db_postal'] = DBPostal.with_db(db_conn)
rt_credentials = config.rt_credentials()
rt_client = config.rt_client(rt_credentials)
if rt_client is None:
rt_ticket = RTTicket
else:
rt_ticket = RTTicket.with_client(rt_client, cache_key)
rt_ticket = RTTicket.with_client(rt_client, rt_credentials.idstr())
functions = cls.functions.copy()
functions[('rt_ticket', str, str)] = rt_ticket
functions[('rt_ticket', str, str, int)] = rt_ticket
return type(cls.__name__, (cls,), {'functions': functions})
return type(cls.__name__, (cls,), {
'columns': columns,
'functions': functions,
})
class FilterPostingsEnvironment(bc_query_env.FilterPostingsEnvironment, _EnvironmentMixin):
columns: EnvironmentColumns # type:ignore[assignment]
functions: EnvironmentFunctions = bc_query_env.FilterPostingsEnvironment.functions.copy() # type:ignore[assignment]
functions['meta_docs'] = MetaDocs
functions['str_meta'] = StrMeta
class TargetsEnvironment(bc_query_env.TargetsEnvironment, _EnvironmentMixin):
columns: EnvironmentColumns # type:ignore[assignment]
functions: EnvironmentFunctions = FilterPostingsEnvironment.functions.copy() # type:ignore[assignment]
functions.update(bc_query_env.AGGREGATOR_FUNCTIONS)
functions['set'] = AggregateSet
@ -424,8 +506,8 @@ class BQLShell(bc_query_shell.BQLShell):
rt_credentials = config.rt_credentials()
rt_key = rt_credentials.idstr()
rt_client = config.rt_client(rt_credentials)
self.env_postings = FilterPostingsEnvironment.with_rt_client(rt_client, rt_key)()
self.env_targets = TargetsEnvironment.with_rt_client(rt_client, rt_key)()
self.env_postings = FilterPostingsEnvironment.with_config(config)()
self.env_targets = TargetsEnvironment.with_config(config)()
self.ods = QueryODS(config.rt_wrapper(rt_credentials))
self.last_line_parsed = ''

View file

@ -5,7 +5,7 @@ from setuptools import setup
setup(
name='conservancy_beancount',
description="Plugin, library, and reports for reading Conservancy's books",
version='1.19.1',
version='1.19.2',
author='Software Freedom Conservancy',
author_email='info@sfconservancy.org',
license='GNU AGPLv3+',