From 8af45e5f8a5b72c4aea209ed19a0e00f9864e705 Mon Sep 17 00:00:00 2001 From: Brett Smith Date: Mon, 8 Mar 2021 13:48:25 -0500 Subject: [PATCH] query: Add BQL functions for dealing with link metadata. query-report was heading to a place where it was going to bifurcate. You could structure input with its own special input switches, and with ODS output, it would have its own dedicated grouping logic and use that. But those things shouldn't be tied together for users. Instead, add functions to BQL to be able to do the kind of grouping we want. This commit adds those. Next we'll extend the ODS output to detect and format these groups correctly. --- conservancy_beancount/reports/query.py | 228 ++++++++++++------------- tests/test_reports_query.py | 145 +--------------- 2 files changed, 115 insertions(+), 258 deletions(-) diff --git a/conservancy_beancount/reports/query.py b/conservancy_beancount/reports/query.py index d6b0062..77a5222 100644 --- a/conservancy_beancount/reports/query.py +++ b/conservancy_beancount/reports/query.py @@ -22,6 +22,7 @@ from typing import ( Dict, Iterable, Iterator, + List, Mapping, NamedTuple, Optional, @@ -63,12 +64,27 @@ BUILTIN_FIELDS: AbstractSet[str] = frozenset(itertools.chain( bc_query_env.TargetsEnvironment.functions, # type:ignore[has-type] )) PROGNAME = 'query-report' -QUERY_PARSER = bc_query_parser.Parser() logger = logging.getLogger('conservancy_beancount.reports.query') CellFunc = Callable[[Any], odf.table.TableCell] +EnvironmentFunctions = Dict[ + # The real key type is something like: + # Union[str, Tuple[str, Type, ...]] + # but two issues with that. One, you can't use Ellipses in a Tuple like + # that, so there's no short way to declare this. Second, Beancount doesn't + # declare it anyway, and mypy infers it as Sequence[object]. So just use + # that. + Sequence[object], + Type[bc_query_compile.EvalFunction], +] RowTypes = Sequence[Tuple[str, Type]] Rows = Sequence[NamedTuple] +Store = List[Any] +QueryStatement = Union[ + bc_query_parser.Balances, + bc_query_parser.Journal, + bc_query_parser.Select, +] class BooksLoader: """Closure to load books with a zero-argument callable @@ -176,6 +192,73 @@ class QueryODS(core.BaseODS[NamedTuple, None]): )) +# This class mostly supports type checking. Beancount code dynamically sets the +# ``store`` attribute, in bc_query_execute.execute_query(). +class Context(bc_query_execute.RowContext): + store: Store + + +class MetaDocs(bc_query_env.AnyMeta): + """Return a list of document links from metadata.""" + def __init__(self, operands: List[str]) -> None: + super(bc_query_env.AnyMeta, self).__init__(operands, list) + + def __call__(self, context: Context) -> List[str]: + raw_value = super().__call__(context) + if isinstance(raw_value, str): + return raw_value.split() + else: + return [] + + +class StrMeta(bc_query_env.AnyMeta): + """Looks up metadata like AnyMeta, then always returns a string.""" + def __init__(self, operands: List[str]) -> None: + super(bc_query_env.AnyMeta, self).__init__(operands, str) + + def __call__(self, context: Context) -> str: + raw_value = super().__call__(context) + if raw_value is None: + return '' + else: + return str(raw_value) + + +class AggregateSet(bc_query_compile.EvalAggregator): + __intypes__ = [object] + + def __init__(self, operands: List[str]) -> None: + super().__init__(operands, set) + + def allocate(self, allocator: bc_query_execute.Allocator) -> None: + self.handle = allocator.allocate() + + def initialize(self, store: Store) -> None: + store[self.handle] = self.dtype() + + def update(self, store: Store, context: Context) -> None: + value, = self.eval_args(context) + if isinstance(value, Sequence) and not isinstance(value, str): + store[self.handle].update(value) + else: + store[self.handle].add(value) + + def __call__(self, context: Context) -> set: + return context.store[self.handle] # type:ignore[no-any-return] + + +class FilterPostingsEnvironment(bc_query_env.FilterPostingsEnvironment): + functions: EnvironmentFunctions = bc_query_env.FilterPostingsEnvironment.functions.copy() # type:ignore[assignment] + functions['meta_docs'] = MetaDocs + functions['str_meta'] = StrMeta + + +class TargetsEnvironment(bc_query_env.TargetsEnvironment): + functions = FilterPostingsEnvironment.functions.copy() + functions.update(bc_query_env.AGGREGATOR_FUNCTIONS) + functions['set'] = AggregateSet + + class BQLShell(bc_query_shell.BQLShell): def __init__( self, @@ -187,9 +270,11 @@ class BQLShell(bc_query_shell.BQLShell): rt_wrapper: Optional[rtutil.RT]=None, ) -> None: super().__init__(is_interactive, loadfun, outfile, default_format, do_numberify) + self.env_postings = FilterPostingsEnvironment() + self.env_targets = TargetsEnvironment() self.ods = QueryODS(rt_wrapper) - def on_Select(self, statement: str) -> None: + def on_Select(self, statement: QueryStatement) -> None: output_format: str = self.vars['format'] try: render_func = getattr(self, f'_render_{output_format}') @@ -206,7 +291,7 @@ class BQLShell(bc_query_shell.BQLShell): row_types, rows = bc_query_execute.execute_query( compiled_query, self.entries, self.options_map, ) - if self.vars['numberify'] and output_format != 'ods': + if self.vars['numberify']: logger.debug("numberifying query") row_types, rows = bc_query_numberify.numberify_results( row_types, rows, self.options_map['dcontext'].build(), @@ -251,14 +336,6 @@ class BQLShell(bc_query_shell.BQLShell): ) -class JoinOperator(enum.Enum): - AND = 'AND' - OR = 'OR' - - def join(self, parts: Iterable[str]) -> str: - return f' {self.value} '.join(parts) - - class ReportFormat(enum.Enum): TEXT = 'text' TXT = TEXT @@ -266,54 +343,6 @@ class ReportFormat(enum.Enum): ODS = 'ods' -def _date_condition( - date: Union[int, datetime.date], - year_to_date: Callable[[int], datetime.date], - op: str, -) -> str: - if isinstance(date, int): - date = year_to_date(date) - return f'date {op} {date.isoformat()}' - -def build_query( - args: argparse.Namespace, - fy: books.FiscalYear, - in_file: Optional[TextIO]=None, -) -> Optional[str]: - if not args.query: - args.query = [] if in_file is None else [line[:-1] for line in in_file] - plain_query = ' '.join(args.query) - if not plain_query or plain_query.isspace(): - return None - try: - QUERY_PARSER.parse(plain_query) - except bc_query_parser.ParseError: - if args.join is None: - args.join = JoinOperator.AND - select = [ - 'date', - 'ANY_META("entity") AS entity', - 'narration AS description', - 'COST(position) AS booked_amount', - *(f'ANY_META("{field}") AS {field.replace("-", "_")}' - if field not in BUILTIN_FIELDS - and re.fullmatch(r'[a-z][-_A-Za-z0-9]*', field) - else field - for field in args.select), - ] - conds = [f'({args.join.join(args.query)})'] - if args.start_date is not None: - conds.append(_date_condition(args.start_date, fy.first_date, '>=')) - if args.stop_date is not None: - conds.append(_date_condition(args.stop_date, fy.next_fy_date, '<')) - return f'SELECT {", ".join(select)} WHERE {" AND ".join(conds)}' - else: - if args.join: - raise ValueError("cannot specify --join with a full query") - if args.select: - raise ValueError("cannot specify --select with a full query") - return plain_query - def parse_arguments(arglist: Optional[Sequence[str]]=None) -> argparse.Namespace: parser = argparse.ArgumentParser(prog=PROGNAME) cliutil.add_version_argument(parser) @@ -321,20 +350,18 @@ def parse_arguments(arglist: Optional[Sequence[str]]=None) -> argparse.Namespace parser.add_argument( '--begin', '--start', '-b', dest='start_date', - metavar='DATE', + metavar='YEAR', type=cliutil.year_or_date_arg, - help="""Begin loading entries from this fiscal year. When query-report -builds the query, it will include a condition `date >= DATE`. + help="""Begin loading entries from this fiscal year. You can specify a +full date, and %(prog)s will use the fiscal year for that date. """) parser.add_argument( '--end', '--stop', '-e', dest='stop_date', - metavar='DATE', + metavar='YEAR', type=cliutil.year_or_date_arg, - help="""End loading entries from this fiscal year. When query-report -builds the query, it will include a condition `date < DATE`. If you specify a -begin date but not an end date, the default end date will be the end of the -fiscal year of the begin date. + help="""End loading entries at this fiscal year. You can specify a +full date, and %(prog)s will use the fiscal year for that date. """) cliutil.add_rewrite_rules_argument(parser) format_arg = cliutil.EnumArgument(ReportFormat) @@ -344,7 +371,8 @@ fiscal year of the begin date. type=format_arg.enum_type, help="""Format of report to generate. Choices are {format_arg.choices_str()}. Default is guessed from your output filename -extension, or 'text' if that fails. +extension. If that fails, default is 'text' for interactive output, and 'ods' +otherwise. """) parser.add_argument( '--numberify', '-m', @@ -359,44 +387,12 @@ extension, or 'text' if that fails. The default is stdout for text and CSV reports, and a generated filename for ODS reports. """) - - query_group = parser.add_argument_group("query options", """ -You can write a single full query as a command line argument (like bean-query), -or you can write individual WHERE condition(s) as arguments. If you write -WHERE conditions, these options are used to build the rest of the query. -""") - join_arg = cliutil.EnumArgument(JoinOperator) - query_group.add_argument( - '--select', '-s', - metavar='COLUMN', - default=[], - action=cliutil.ExtendAction, - help="""Columns to select. You can write these as comma-separated -names, and/or specify the option more than once. You can specify both -bean-query's built-in column names (like `account` and `flag`) and metadata -keys. -""") - query_group.add_argument( - '--group-by', '-g', - metavar='COLUMN', - help="""Group output by this column -""") - # query_group.add_argument( - # '--order-by', '--sort', '-r', - # metavar='COLUMN', - # help="""Order output by this column - # """), - query_group.add_argument( - '--join', '-j', - metavar='OP', - type=join_arg.enum_type, - help=f"""Join your WHERE conditions with this operator. -Choices are {join_arg.choices_str()}. Default 'and'. -"""), - query_group.add_argument( + parser.add_argument( 'query', nargs=argparse.ZERO_OR_MORE, - help="""Full query or WHERE conditions to run non-interactively + default=[], + help="""Query to run non-interactively. If none is provided, and +standard input is not a terminal, reads the query from stdin instead. """) args = parser.parse_args(arglist) @@ -413,20 +409,14 @@ def main(arglist: Optional[Sequence[str]]=None, config = configmod.Config() config.load_file() - fy = config.fiscal_year_begin() - if args.stop_date is None and args.start_date is not None: - args.stop_date = fy.next_fy_date(args.start_date) - try: - query = build_query(args, fy, None if sys.stdin.isatty() else sys.stdin) - except ValueError as error: - logger.error(error.args[0], exc_info=logger.isEnabledFor(logging.DEBUG)) - return 2 - is_interactive = query is None and sys.stdin.isatty() + query = ' '.join(args.query).strip() + if not query and not sys.stdin.isatty(): + query = sys.stdin.read().strip() if args.report_type is None: try: args.report_type = ReportFormat[args.output_file.suffix[1:].upper()] except (AttributeError, KeyError): - args.report_type = ReportFormat.TEXT if is_interactive else ReportFormat.ODS + args.report_type = ReportFormat.ODS if query else ReportFormat.TEXT load_func = BooksLoader( config.books_loader(), @@ -435,7 +425,7 @@ def main(arglist: Optional[Sequence[str]]=None, [rewrite.RewriteRuleset.from_yaml(path) for path in args.rewrite_rules], ) shell = BQLShell( - is_interactive, + not query, load_func, stdout, args.report_type.value, @@ -443,10 +433,10 @@ def main(arglist: Optional[Sequence[str]]=None, config.rt_wrapper(), ) shell.on_Reload() - if query is None: - shell.cmdloop() - else: + if query: shell.onecmd(query) + else: + shell.cmdloop() if not shell.ods.is_empty(): shell.ods.set_common_properties(config.books_repo()) @@ -454,7 +444,7 @@ def main(arglist: Optional[Sequence[str]]=None, if args.output_file is None: out_dir_path = config.repository_path() or Path() args.output_file = out_dir_path / 'QueryResults_{}.ods'.format( - datetime.datetime.now().isoformat(timespec='minutes'), + datetime.datetime.now().isoformat(timespec='seconds'), ) logger.info("Writing spreadsheet to %s", args.output_file) ods_file = cliutil.bytes_output(args.output_file, stdout) diff --git a/tests/test_reports_query.py b/tests/test_reports_query.py index ed2ba64..492069b 100644 --- a/tests/test_reports_query.py +++ b/tests/test_reports_query.py @@ -38,10 +38,6 @@ class MockRewriteRuleset: yield post._replace(units=testutil.Amount(number, currency)) -@pytest.fixture(scope='module') -def fy(): - return FiscalYear(3, 1) - @pytest.fixture(scope='module') def rt(): return rtutil.RT(testutil.RTClient()) @@ -52,15 +48,6 @@ def pipe_main(arglist, config, stdout_type=io.StringIO): returncode = qmod.main(arglist, stdout, stderr, config) return returncode, stdout, stderr -def query_args(query=None, start_date=None, stop_date=None, join=None, select=None): - if isinstance(join, str): - join = qmod.JoinOperator[join] - if select is None: - select = [] - elif isinstance(select, str): - select = select.split(',') - return argparse.Namespace(**locals()) - def test_books_loader_empty(): result = qmod.BooksLoader(None)() assert not result.entries @@ -90,130 +77,6 @@ def test_books_loader_rewrites(): assert numbers assert all(abs(number) >= 40 for number in numbers) -@pytest.mark.parametrize('file_s', [None, '', ' \n \n\n']) -def test_build_query_empty(fy, file_s): - args = query_args() - if file_s is None: - query = qmod.build_query(args, fy) - else: - with io.StringIO(file_s) as qfile: - query = qmod.build_query(args, fy, qfile) - assert query is None - -@pytest.mark.parametrize('query_str', [ - 'SELECT * WHERE date >= 2018-03-01', - 'select *', - 'JOURNAL "Income:Donations"', - 'journal', - 'BALANCES FROM year=2018', - 'balances', -]) -def test_build_query_in_arglist(fy, query_str): - args = query_args(query_str.split(), testutil.PAST_DATE, testutil.FUTURE_DATE) - assert qmod.build_query(args, fy) == query_str - -@pytest.mark.parametrize('argname,argval', [ - ('join', qmod.JoinOperator.AND), - ('join', qmod.JoinOperator.OR), - ('select', 'date,flag'), - ('select', 'position,rt-id'), -]) -def test_build_query_cant_mix_switches_with_full_query(fy, argname, argval): - args = query_args(['journal'], **{argname: argval}) - with pytest.raises(ValueError): - qmod.build_query(args, fy) - -@pytest.mark.parametrize('count,join_op', enumerate(qmod.JoinOperator, 1)) -def test_build_query_where_arglist_conditions(fy, count, join_op): - conds = ['account ~ "^Income:"', 'year >= 2018'][:count] - args = query_args(conds, join=join_op.name) - query = qmod.build_query(args, fy) - assert query.startswith('SELECT ') - cond_index = query.index(' WHERE ') + 7 - assert query[cond_index:] == '({})'.format(join_op.join(conds)) - -@pytest.mark.parametrize('select', [ - ['flag'], - ['check'], - ['flag', 'month'], - ['cost_label', 'cost_metakey'], - ['approval', 'receipt'], -]) -def test_build_query_select_fields(fy, select): - args = query_args(['year>2018'], select=list(select)) - query = qmod.build_query(args, fy) - assert query.startswith('SELECT ') - start_index = 7 - for field in select: - if field != 'flag' and field != 'month' and field != 'cost_label': - field = f'ANY_META("{field}") AS {field.replace("-", "_")}' - match = re.search(rf',\s*{re.escape(field)}\b', query) - assert match, f"field {field!r} not found in query: {query!r}" - assert match.start() >= start_index - start_index = match.end() - assert query[start_index:start_index + 7] == ' WHERE ' - -@pytest.mark.parametrize('argname,date_arg', itertools.product( - ['start_date', 'stop_date'], - [testutil.FY_START_DATE, testutil.FY_START_DATE.year], -)) -def test_build_query_one_date_arg(fy, argname, date_arg): - query_kwargs = { - argname: date_arg, - 'query': ['flag = "*"', 'flag = "!"'], - 'join': 'OR', - } - args = query_args(**query_kwargs) - query = qmod.build_query(args, fy) - assert query.startswith('SELECT ') - cond_index = query.index(' WHERE ') + 7 - if argname == 'start_date': - expect_op = '>=' - year_to_date = fy.first_date - else: - expect_op = '<' - year_to_date = fy.next_fy_date - if not isinstance(date_arg, datetime.date): - date_arg = year_to_date(date_arg) - assert query[cond_index:] == '({}) AND date {} {}'.format( - ' OR '.join(query_kwargs['query']), expect_op, date_arg.isoformat(), - ) - -@pytest.mark.parametrize('start_date,stop_date', itertools.product( - [testutil.PAST_DATE, testutil.PAST_DATE.year], - [testutil.FUTURE_DATE, testutil.FUTURE_DATE.year], -)) -def test_build_query_two_date_args(fy, start_date, stop_date): - args = query_args(['account ~ "^Equity:"'], start_date, stop_date, 'AND') - query = qmod.build_query(args, fy) - assert query.startswith('SELECT ') - cond_index = query.index(' WHERE ') + 7 - if isinstance(start_date, int): - start_date = fy.first_date(start_date) - if isinstance(stop_date, int): - stop_date = fy.next_fy_date(stop_date) - assert query[cond_index:] == '({}) AND date >= {} AND date < {}'.format( - args.query[0], start_date.isoformat(), stop_date.isoformat(), - ) - -def test_build_query_plain_from_file(fy): - with io.StringIO("SELECT *\n WHERE account ~ '^Assets:';\n") as qfile: - query = qmod.build_query(query_args(), fy, qfile) - assert re.fullmatch(r"SELECT \*\s+WHERE account ~ '\^Assets:';\s*", query) - -def test_build_query_from_file_where_clauses(fy): - conds = ["account ~ '^Income:'", "account ~ '^Expenses:'"] - args = query_args(None, testutil.PAST_DATE, testutil.FUTURE_DATE, 'OR') - with io.StringIO(''.join(f'{s}\n' for s in conds)) as qfile: - query = qmod.build_query(args, fy, qfile) - assert query.startswith('SELECT ') - cond_index = query.index(' WHERE ') + 7 - assert query[cond_index:] == '({}) AND date >= {} AND date < {}'.format( - ' OR '.join(conds), - testutil.PAST_DATE.isoformat(), - testutil.FUTURE_DATE.isoformat(), - ) - @pytest.mark.parametrize('arglist,fy', testutil.combine_values( [['--report-type', 'text'], ['--format=text'], ['-f', 'txt']], range(2018, 2021), @@ -362,7 +225,11 @@ def test_ods_is_empty(): def test_ods_output(fy, account, amt_prefix): books_path = testutil.test_path(f'books/books/{fy}.beancount') config = testutil.TestConfig(books_path=books_path) - arglist = ['-O', '-', '-f', 'ods', f'account ~ "^{account}:"'] + arglist = [ + '-O', '-', + '-f', 'ods', + f'SELECT date, narration, UNITS(position) WHERE account ~ "^{account}:"', + ] returncode, stdout, stderr = pipe_main(arglist, config, io.BytesIO) assert returncode == 0 stdout.seek(0) @@ -371,7 +238,7 @@ def test_ods_output(fy, account, amt_prefix): next(rows) # Skip header row amt_pattern = rf'^{re.escape(amt_prefix)}\d' for count, row in enumerate(rows, 1): - date, entity, narration, amount = row.childNodes + date, narration, amount = row.childNodes assert re.fullmatch(rf'{fy}-\d{{2}}-\d{{2}}', date.text) assert narration.text.startswith(f'{fy} ') assert re.match(amt_pattern, amount.text)