diff --git a/conservancy_beancount/reports/query.py b/conservancy_beancount/reports/query.py index cba0102..7588106 100644 --- a/conservancy_beancount/reports/query.py +++ b/conservancy_beancount/reports/query.py @@ -58,14 +58,18 @@ from typing import ( cast, Any, Callable, + ClassVar, Dict, + Hashable, Iterable, Iterator, List, Mapping, + MutableMapping, NamedTuple, Optional, Sequence, + Set, TextIO, Tuple, Type, @@ -74,6 +78,9 @@ from typing import ( from ..beancount_types import ( MetaKey, MetaValue, + OptionsMap, + Posting, + Transaction, ) from decimal import Decimal @@ -90,6 +97,7 @@ import beancount.query.query_parser as bc_query_parser import beancount.query.query_render as bc_query_render import beancount.query.shell as bc_query_shell import odf.table # type:ignore[import] +import rt from . import core from . import rewrite @@ -100,6 +108,7 @@ from .. import data from .. import rtutil PROGNAME = 'query-report' +SENTINEL = object() logger = logging.getLogger('conservancy_beancount.reports.query') CellFunc = Callable[[Any], odf.table.TableCell] @@ -115,6 +124,7 @@ EnvironmentFunctions = Dict[ ] RowTypes = Sequence[Tuple[str, Type]] Rows = Sequence[NamedTuple] +RTResult = Optional[Mapping[Any, Any]] Store = List[Any] QueryExpression = Union[ bc_query_parser.Column, @@ -128,20 +138,30 @@ QueryStatement = Union[ bc_query_parser.Select, ] -# This class supports type checking. Beancount code dynamically sets the -# ``store`` attribute, in bc_query_execute.execute_query(). -class Context(bc_query_execute.RowContext): +# This class annotates the types that Beancount's RowContexts have when they're +# passed to EvalFunction.__call__(). These types get set across +# create_row_context and execute_query. +class PostingContext: + posting: Posting + entry: Transaction + balance: Inventory + options_map: OptionsMap + account_types: Mapping + open_close_map: Mapping + commodity_map: Mapping + price_map: Mapping + # Dynamically set by execute_query store: Store class MetaDocs(bc_query_env.AnyMeta): """Return a list of document links from metadata.""" - def __init__(self, operands: List[str]) -> None: + def __init__(self, operands: List[bc_query_compile.EvalNode]) -> None: super(bc_query_env.AnyMeta, self).__init__(operands, list) # The second argument is our return type. # It should match the annotated return type of __call__. - def __call__(self, context: Context) -> List[str]: + def __call__(self, context: PostingContext) -> List[str]: raw_value = super().__call__(context) if isinstance(raw_value, str): return raw_value.split() @@ -149,12 +169,143 @@ class MetaDocs(bc_query_env.AnyMeta): return [] +class RTField(NamedTuple): + key: str + parse: Optional[Callable[[str], object]] + unset_value: Optional[str] = None + + def load(self, rt_ticket: RTResult) -> object: + value = rt_ticket.get(self.key) if rt_ticket else None + if not value or value == self.unset_value: + return None + elif self.parse is None: + return value + else: + return self.parse(value) + + +class RTTicket(bc_query_compile.EvalFunction): + """Look up a field from RT ticket(s) mentioned in metadata documentation""" + __intypes__ = [str, str, int] + FIELDS = {key: RTField(key, None) for key in [ + 'AdminCc', + 'Cc', + 'Creator', + 'Owner', + 'Queue', + 'Status', + 'Subject', + 'Requestors', + ]} + FIELDS.update((key, RTField(key, int, '0')) for key in [ + 'numerical_id', + 'FinalPriority', + 'InitialPriority', + 'Priority', + 'TimeEstimated', + 'TimeLeft', + 'TimeWorked', + ]) + FIELDS.update((key, RTField(key, rtutil.RTDateTime, 'Not set')) for key in [ + 'Created', + 'Due', + 'LastUpdated', + 'Resolved', + 'Started', + 'Starts', + 'Told', + ]) + FIELDS.update({key.lower(): value for key, value in FIELDS.items()}) + FIELDS['id'] = FIELDS['numerical_id'] + FIELDS['AdminCC'] = FIELDS['AdminCc'] + FIELDS['CC'] = FIELDS['Cc'] + RT_CLIENT: ClassVar[rt.Rt] + # _CACHES holds all of the caches for different RT instances that have + # been passed through RTTicket.with_client(). + _CACHES: ClassVar[Dict[Hashable, MutableMapping[str, RTResult]]] = {} + # _rt_cache is the cache specific to this RT_CLIENT. + _rt_cache: ClassVar[MutableMapping[str, RTResult]] = {} + + @classmethod + def with_client(cls, client: rt.Rt, cache_key: Hashable) -> Type['RTTicket']: + return type(cls.__name__, (cls,), { + 'RT_CLIENT': client, + '_rt_cache': cls._CACHES.setdefault(cache_key, {}), + }) + + def __init__(self, operands: List[bc_query_compile.EvalNode]) -> None: + if not hasattr(self, 'RT_CLIENT'): + raise RuntimeError("no RT client available - cannot use rt_ticket()") + rt_op, meta_op, *rest = operands + # We have to evaluate the RT and meta keys on each call, because they + # might themselves be dynamic. In the common case they're constants. + # In that case, check for typos so we can report an error to the user + # before execution even begins. + if isinstance(rt_op, bc_query_compile.EvalConstant): + self._rt_key(rt_op.value) + if isinstance(meta_op, bc_query_compile.EvalConstant): + self._meta_key(meta_op.value) + if not rest: + operands.append(bc_query_compile.EvalConstant(sys.maxsize)) + super().__init__(operands, list) + + def _rt_key(self, key: str) -> RTField: + try: + return self.FIELDS[key] + except KeyError: + raise ValueError(f"unknown RT ticket field {key!r}") from None + + def _meta_key(self, key: str) -> str: + if key in data.LINK_METADATA: + return key + else: + raise ValueError(f"metadata key {key!r} does not contain documentation links") + + def __call__(self, context: PostingContext) -> list: + rt_key: str + meta_key: str + limit: int + rt_key, meta_key, limit = self.eval_args(context) + rt_field = self._rt_key(rt_key) + meta_key = self._meta_key(meta_key) + if context.posting.meta is None: + meta_value: Any = SENTINEL + else: + meta_value = context.posting.meta.get(meta_key, SENTINEL) + if meta_value is SENTINEL: + meta_value = context.entry.meta.get(meta_key) + if not isinstance(meta_value, str) or limit < 1: + meta_value = '' + ticket_ids: Set[str] = set() + for link_s in meta_value.split(): + rt_id = rtutil.RT.parse(link_s) + if rt_id is not None: + ticket_ids.add(rt_id[0]) + if len(ticket_ids) >= limit: + break + retval: List[object] = [] + for ticket_id in ticket_ids: + try: + rt_ticket = self._rt_cache[ticket_id] + except KeyError: + rt_ticket = self.RT_CLIENT.get_ticket(ticket_id) + self._rt_cache[ticket_id] = rt_ticket + field_value = rt_field.load(rt_ticket) + if field_value is None: + pass + elif isinstance(field_value, list): + retval.extend(field_value) + else: + retval.append(field_value) + return retval + + class StrMeta(bc_query_env.AnyMeta): """Looks up metadata like AnyMeta, then always returns a string.""" - def __init__(self, operands: List[str]) -> None: + def __init__(self, operands: List[bc_query_compile.EvalNode]) -> None: super(bc_query_env.AnyMeta, self).__init__(operands, str) - def __call__(self, context: Context) -> str: + def __call__(self, context: PostingContext) -> str: raw_value = super().__call__(context) if raw_value is None: return '' @@ -166,7 +317,7 @@ class AggregateSet(bc_query_compile.EvalAggregator): """Filter argument values that aren't unique.""" __intypes__ = [object] - def __init__(self, operands: List[str]) -> None: + def __init__(self, operands: List[bc_query_compile.EvalNode]) -> None: super().__init__(operands, set) def allocate(self, allocator: bc_query_execute.Allocator) -> None: @@ -179,7 +330,7 @@ class AggregateSet(bc_query_compile.EvalAggregator): # self.dtype() is our return type, aka the second argument to __init__ # above, aka the annotated return type of __call__. - def update(self, store: Store, context: Context) -> None: + def update(self, store: Store, context: PostingContext) -> None: """Update existing storage with new result data.""" value, = self.eval_args(context) if isinstance(value, Sequence) and not isinstance(value, (str, tuple)): @@ -187,19 +338,38 @@ class AggregateSet(bc_query_compile.EvalAggregator): else: store[self.handle].add(value) - def __call__(self, context: Context) -> set: + def __call__(self, context: PostingContext) -> set: """Return the result for an aggregation.""" return context.store[self.handle] # type:ignore[no-any-return] -class FilterPostingsEnvironment(bc_query_env.FilterPostingsEnvironment): +class _EnvironmentMixin: + functions: EnvironmentFunctions + + @classmethod + def with_rt_client( + cls, + rt_client: Optional[rt.Rt], + cache_key: Hashable, + ) -> Type['_EnvironmentMixin']: + if rt_client is None: + rt_ticket = RTTicket + else: + rt_ticket = RTTicket.with_client(rt_client, cache_key) + functions = cls.functions.copy() + functions[('rt_ticket', str, str)] = rt_ticket + functions[('rt_ticket', str, str, int)] = rt_ticket + return type(cls.__name__, (cls,), {'functions': functions}) + + +class FilterPostingsEnvironment(bc_query_env.FilterPostingsEnvironment, _EnvironmentMixin): functions: EnvironmentFunctions = bc_query_env.FilterPostingsEnvironment.functions.copy() # type:ignore[assignment] functions['meta_docs'] = MetaDocs functions['str_meta'] = StrMeta -class TargetsEnvironment(bc_query_env.TargetsEnvironment): - functions = FilterPostingsEnvironment.functions.copy() +class TargetsEnvironment(bc_query_env.TargetsEnvironment, _EnvironmentMixin): + functions: EnvironmentFunctions = FilterPostingsEnvironment.functions.copy() # type:ignore[assignment] functions.update(bc_query_env.AGGREGATOR_FUNCTIONS) functions['set'] = AggregateSet @@ -244,17 +414,20 @@ class BooksLoader: class BQLShell(bc_query_shell.BQLShell): def __init__( self, + config: configmod.Config, is_interactive: bool, loadfun: Callable[[], books.LoadResult], outfile: TextIO, default_format: str='text', do_numberify: bool=False, - rt_wrapper: Optional[rtutil.RT]=None, ) -> None: super().__init__(is_interactive, loadfun, outfile, default_format, do_numberify) - self.env_postings = FilterPostingsEnvironment() - self.env_targets = TargetsEnvironment() - self.ods = QueryODS(rt_wrapper) + rt_credentials = config.rt_credentials() + rt_key = rt_credentials.idstr() + rt_client = config.rt_client(rt_credentials) + self.env_postings = FilterPostingsEnvironment.with_rt_client(rt_client, rt_key)() + self.env_targets = TargetsEnvironment.with_rt_client(rt_client, rt_key)() + self.ods = QueryODS(config.rt_wrapper(rt_credentials)) self.last_line_parsed = '' def run_parser( @@ -621,12 +794,12 @@ def main(arglist: Optional[Sequence[str]]=None, [rewrite.RewriteRuleset.from_yaml(path) for path in args.rewrite_rules], ) shell = BQLShell( + config, not query, load_func, stdout, args.report_type.value, args.numberify, - config.rt_wrapper(), ) shell.on_Reload() if query: diff --git a/setup.py b/setup.py index 5144736..ef7b126 100755 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ from setuptools import setup setup( name='conservancy_beancount', description="Plugin, library, and reports for reading Conservancy's books", - version='1.19.0', + version='1.19.1', author='Software Freedom Conservancy', author_email='info@sfconservancy.org', license='GNU AGPLv3+', diff --git a/tests/test_reports_query.py b/tests/test_reports_query.py index 8f49478..fa5270b 100644 --- a/tests/test_reports_query.py +++ b/tests/test_reports_query.py @@ -21,6 +21,8 @@ import pytest from . import testutil from beancount.core import data as bc_data +from beancount.query import query_compile as bc_query_compile +from beancount.query import query_execute as bc_query_execute from beancount.query import query_parser as bc_query_parser from conservancy_beancount.books import FiscalYear from conservancy_beancount.reports import query as qmod @@ -28,6 +30,8 @@ from conservancy_beancount import rtutil from decimal import Decimal +UTC = datetime.timezone.utc + class MockRewriteRuleset: def __init__(self, multiplier=2): self.multiplier = multiplier @@ -39,6 +43,13 @@ class MockRewriteRuleset: yield post._replace(units=testutil.Amount(number, currency)) +class RowContext(bc_query_execute.RowContext): + def __init__(self, entry, posting=None): + super().__init__() + self.entry = entry + self.posting = posting + + @pytest.fixture(scope='module') def qparser(): return bc_query_parser.Parser() @@ -47,12 +58,128 @@ def qparser(): def rt(): return rtutil.RT(testutil.RTClient()) +@pytest.fixture(scope='module') +def ticket_query(): + return qmod.RTTicket.with_client(testutil.RTClient(), 'testfixture') + +def const_operands(*args): + return [bc_query_compile.EvalConstant(v) for v in args] + def pipe_main(arglist, config, stdout_type=io.StringIO): stdout = stdout_type() stderr = io.StringIO() returncode = qmod.main(arglist, stdout, stderr, config) return returncode, stdout, stderr +def test_rt_ticket_unconfigured(): + with pytest.raises(RuntimeError): + qmod.RTTicket(const_operands('id', 'rt-id')) + +@pytest.mark.parametrize('field_name', ['foo', 'bar']) +def test_rt_ticket_bad_field(ticket_query, field_name): + with pytest.raises(ValueError): + ticket_query(const_operands(field_name, 'rt-id')) + +@pytest.mark.parametrize('meta_name', ['foo', 'bar']) +def test_rt_ticket_bad_metadata(ticket_query, meta_name): + with pytest.raises(ValueError): + ticket_query(const_operands('id', meta_name)) + +@pytest.mark.parametrize('field_name,meta_name,expected', [ + ('id', 'rt-id', 1), + ('Queue', 'approval', 'general'), + ('Requestors', 'invoice', ['mx1@example.org', 'requestor2@example.org']), + ('Due', 'tax-reporting', datetime.datetime(2017, 1, 14, 12, 1, 0, tzinfo=UTC)), +]) +def test_rt_ticket_from_txn(ticket_query, field_name, meta_name, expected): + func = ticket_query(const_operands(field_name, meta_name)) + txn = testutil.Transaction(**{meta_name: 'rt:1'}, postings=[ + ('Assets:Cash', 80), + ]) + context = RowContext(txn, txn.postings[0]) + if not isinstance(expected, list): + expected = [expected] + assert func(context) == expected + +@pytest.mark.parametrize('field_name,meta_name,expected', [ + ('id', 'rt-id', 2), + ('Queue', 'approval', 'general'), + ('Requestors', 'invoice', ['mx2@example.org', 'requestor2@example.org']), + ('Due', 'tax-reporting', datetime.datetime(2017, 1, 14, 12, 2, 0, tzinfo=UTC)), +]) +def test_rt_ticket_from_post(ticket_query, field_name, meta_name, expected): + func = ticket_query(const_operands(field_name, meta_name)) + txn = testutil.Transaction(**{meta_name: 'rt:1'}, postings=[ + ('Assets:Cash', 110, {meta_name: 'rt:2/8'}), + ]) + context = RowContext(txn, txn.postings[0]) + if not isinstance(expected, list): + expected = [expected] + assert func(context) == expected + +@pytest.mark.parametrize('field_name,meta_name,expected,on_txn', [ + ('id', 'approval', [1, 2], True), + ('Queue', 'check', ['general', 'general'], False), + ('Requestors', 'invoice', [ + 'mx1@example.org', + 'mx2@example.org', + 'requestor2@example.org', + 'requestor2@example.org', + ], False), +]) +def test_rt_ticket_multi_results(ticket_query, field_name, meta_name, expected, on_txn): + func = ticket_query(const_operands(field_name, meta_name)) + txn = testutil.Transaction(**{'rt-id': 'rt:1'}, postings=[ + ('Assets:Cash', 110, {'rt-id': 'rt:2'}), + ]) + post = txn.postings[0] + meta = txn.meta if on_txn else post.meta + meta[meta_name] = 'rt:1/2 Docs/12.pdf rt:2/8' + context = RowContext(txn, post) + assert sorted(func(context)) == expected + +@pytest.mark.parametrize('meta_value,on_txn', testutil.combine_values( + ['', 'Docs/34.pdf', 'Docs/100.pdf Docs/120.pdf'], + [True, False], +)) +def test_rt_ticket_no_results(ticket_query, meta_value, on_txn): + func = ticket_query(const_operands('Queue', 'check')) + txn = testutil.Transaction(**{'rt-id': 'rt:1'}, postings=[ + ('Assets:Cash', 110, {'rt-id': 'rt:2'}), + ]) + post = txn.postings[0] + meta = txn.meta if on_txn else post.meta + meta['check'] = meta_value + context = RowContext(txn, post) + assert func(context) == [] + +def test_rt_ticket_caches_tickets(): + rt_client = testutil.RTClient() + rt_client.TICKET_DATA = testutil.RTClient.TICKET_DATA.copy() + ticket_query = qmod.RTTicket.with_client(rt_client, 'cachetestA') + func = ticket_query(const_operands('id', 'rt-id')) + txn = testutil.Transaction(postings=[ + ('Assets:Cash', 160, {'rt-id': 'rt:3'}), + ]) + context = RowContext(txn, txn.postings[0]) + assert func(context) == [3] + del rt_client.TICKET_DATA['3'] + assert func(context) == [3] + +def test_rt_ticket_caches_tickets_not_found(): + rt_client = testutil.RTClient() + rt_client.TICKET_DATA = testutil.RTClient.TICKET_DATA.copy() + rt3 = rt_client.TICKET_DATA.pop('3') + ticket_query = qmod.RTTicket.with_client(rt_client, 'cachetestB') + func = ticket_query(const_operands('id', 'rt-id')) + txn = testutil.Transaction(postings=[ + ('Assets:Cash', 160, {'rt-id': 'rt:3'}), + ]) + context = RowContext(txn, txn.postings[0]) + assert func(context) == [] + rt_client.TICKET_DATA['3'] = rt3 + assert func(context) == [] + def test_books_loader_empty(): result = qmod.BooksLoader(None)() assert not result.entries diff --git a/tests/testutil.py b/tests/testutil.py index 11a9439..9aca10e 100644 --- a/tests/testutil.py +++ b/tests/testutil.py @@ -25,6 +25,7 @@ from pathlib import Path from typing import Any, Optional, NamedTuple from conservancy_beancount import books, data, rtutil +from conservancy_beancount.config import RTCredentials EXTREME_FUTURE_DATE = datetime.date(datetime.MAXYEAR, 12, 30) FUTURE_DATE = datetime.date.today() + datetime.timedelta(days=365 * 99) @@ -286,10 +287,13 @@ class TestConfig: def repository_path(self): return self.repo_path - def rt_client(self): + def rt_credentials(self): + return RTCredentials('https://example.org/testrt', 'testuser', 'testpass') + + def rt_client(self, credentials=None): return self._rt_client - def rt_wrapper(self): + def rt_wrapper(self, credentials=None): return self._rt_wrapper @@ -417,9 +421,13 @@ class RTClient: ticket_id_s = str(ticket_id) if ticket_id_s not in self.TICKET_DATA: return None + ticket_id_n = int(ticket_id) retval = { 'id': 'ticket/{}'.format(ticket_id_s), 'numerical_id': ticket_id_s, + 'Created': f'2016-12-15T14:{ticket_id_n:02d}:40Z', + 'Due': f'2017-01-14T12:{ticket_id_n:02d}:00Z', + 'Queue': 'general', 'Requestors': [ f'mx{ticket_id_s}@example.org', 'requestor2@example.org',