From 221d42a4791ae973d0272656da0129d70b20ae7a Mon Sep 17 00:00:00 2001 From: Brett Smith Date: Fri, 5 Mar 2021 16:45:24 -0500 Subject: [PATCH] query: Add --select to use with WHERE condition arguments. --- conservancy_beancount/reports/query.py | 85 ++++++++++++++++++++------ tests/test_reports_query.py | 41 ++++++++++++- 2 files changed, 106 insertions(+), 20 deletions(-) diff --git a/conservancy_beancount/reports/query.py b/conservancy_beancount/reports/query.py index 302423f..c870829 100644 --- a/conservancy_beancount/reports/query.py +++ b/conservancy_beancount/reports/query.py @@ -6,17 +6,16 @@ # LICENSE.txt in the repository. import argparse -import collections import datetime import enum -import functools -import locale +import itertools import logging import re import sys from typing import ( cast, + AbstractSet, Callable, Dict, Iterable, @@ -37,8 +36,9 @@ from ..beancount_types import ( from decimal import Decimal from pathlib import Path -import beancount.query.shell as bc_query +import beancount.query.query_env as bc_query_env import beancount.query.query_parser as bc_query_parser +import beancount.query.shell as bc_query from . import core from . import rewrite @@ -47,6 +47,10 @@ from .. import cliutil from .. import config as configmod from .. import data +BUILTIN_FIELDS: AbstractSet[str] = frozenset(itertools.chain( + bc_query_env.TargetsEnvironment.columns, # type:ignore[has-type] + bc_query_env.TargetsEnvironment.functions, # type:ignore[has-type] +)) PROGNAME = 'query-report' QUERY_PARSER = bc_query_parser.Parser() logger = logging.getLogger('conservancy_beancount.reports.query') @@ -119,24 +123,43 @@ def build_query( ) -> Optional[str]: if not args.query: args.query = [] if in_file is None else [line[:-1] for line in in_file] - if not any(re.search(r'\S', s) for s in args.query): - return None plain_query = ' '.join(args.query) + if not plain_query or plain_query.isspace(): + return None try: QUERY_PARSER.parse(plain_query) except bc_query_parser.ParseError: + if args.join is None: + args.join = JoinOperator.AND + select = [ + 'date', + 'ANY_META("entity") as entity', + 'narration', + 'position', + 'COST(position)', + *(f'ANY_META("{field}") AS {field.replace("-", "_")}' + if field not in BUILTIN_FIELDS + and re.fullmatch(r'[a-z][-_A-Za-z0-9]*', field) + else field + for field in args.select), + ] conds = [f'({args.join.join(args.query)})'] if args.start_date is not None: conds.append(_date_condition(args.start_date, fy.first_date, '>=')) if args.stop_date is not None: conds.append(_date_condition(args.stop_date, fy.next_fy_date, '<')) - return f'SELECT * WHERE {" AND ".join(conds)}' + return f'SELECT {", ".join(select)} WHERE {" AND ".join(conds)}' else: + if args.join: + raise ValueError("cannot specify --join with a full query") + if args.select: + raise ValueError("cannot specify --select with a full query") return plain_query def parse_arguments(arglist: Optional[Sequence[str]]=None) -> argparse.Namespace: parser = argparse.ArgumentParser(prog=PROGNAME) cliutil.add_version_argument(parser) + cliutil.add_loglevel_argument(parser) parser.add_argument( '--begin', '--start', '-b', dest='start_date', @@ -172,25 +195,47 @@ extension, or 'text' if that fails. help="""Write the report to this file, or stdout when PATH is `-`. The default is stdout for text and CSV reports, and a generated filename for ODS reports. +""") + + query_group = parser.add_argument_group("query options", """ +You can write a single full query as a command line argument (like bean-query), +or you can write individual WHERE condition(s) as arguments. If you write +WHERE conditions, these options are used to build the rest of the query. """) join_arg = cliutil.EnumArgument(JoinOperator) - parser.add_argument( + query_group.add_argument( + '--select', '-s', + metavar='COLUMN', + default=[], + action=cliutil.ExtendAction, + help="""Columns to select. You can write these as comma-separated +names, and/or specify the option more than once. You can specify both +bean-query's built-in column names (like `account` and `flag`) and metadata +keys. +""") + query_group.add_argument( + '--group-by', '-g', + metavar='COLUMN', + help="""Group output by this column +""") + # query_group.add_argument( + # '--order-by', '--sort', '-r', + # metavar='COLUMN', + # help="""Order output by this column + # """), + query_group.add_argument( '--join', '-j', metavar='OP', type=join_arg.enum_type, - default=JoinOperator.AND, - help="""When you specify multiple WHERE conditions on the command line -and let query-report build the query, join conditions with this operator. + help=f"""Join your WHERE conditions with this operator. Choices are {join_arg.choices_str()}. Default 'and'. -""") - cliutil.add_loglevel_argument(parser) - parser.add_argument( +"""), + query_group.add_argument( 'query', nargs=argparse.ZERO_OR_MORE, - help="""Query to run non-interactively. You can specify a full query -you write yourself, or conditions to follow WHERE and let query-report build -the rest of the query. + help="""Full query or WHERE conditions to run non-interactively """) + args = parser.parse_args(arglist) return args @@ -208,7 +253,11 @@ def main(arglist: Optional[Sequence[str]]=None, fy = config.fiscal_year_begin() if args.stop_date is None and args.start_date is not None: args.stop_date = fy.next_fy_date(args.start_date) - query = build_query(args, fy, None if sys.stdin.isatty() else sys.stdin) + try: + query = build_query(args, fy, None if sys.stdin.isatty() else sys.stdin) + except ValueError as error: + logger.error(error.args[0], exc_info=logger.isEnabledFor(logging.DEBUG)) + return 2 is_interactive = query is None and sys.stdin.isatty() if args.report_type is None: try: diff --git a/tests/test_reports_query.py b/tests/test_reports_query.py index 48efd66..9e359e6 100644 --- a/tests/test_reports_query.py +++ b/tests/test_reports_query.py @@ -44,8 +44,13 @@ def pipe_main(arglist, config): returncode = qmod.main(arglist, stdout, stderr, config) return returncode, stdout, stderr -def query_args(query=None, start_date=None, stop_date=None, join='AND'): - join = qmod.JoinOperator[join] +def query_args(query=None, start_date=None, stop_date=None, join=None, select=None): + if isinstance(join, str): + join = qmod.JoinOperator[join] + if select is None: + select = [] + elif isinstance(select, str): + select = select.split(',') return argparse.Namespace(**locals()) def test_books_loader_empty(): @@ -99,6 +104,17 @@ def test_build_query_in_arglist(fy, query_str): args = query_args(query_str.split(), testutil.PAST_DATE, testutil.FUTURE_DATE) assert qmod.build_query(args, fy) == query_str +@pytest.mark.parametrize('argname,argval', [ + ('join', qmod.JoinOperator.AND), + ('join', qmod.JoinOperator.OR), + ('select', 'date,flag'), + ('select', 'position,rt-id'), +]) +def test_build_query_cant_mix_switches_with_full_query(fy, argname, argval): + args = query_args(['journal'], **{argname: argval}) + with pytest.raises(ValueError): + qmod.build_query(args, fy) + @pytest.mark.parametrize('count,join_op', enumerate(qmod.JoinOperator, 1)) def test_build_query_where_arglist_conditions(fy, count, join_op): conds = ['account ~ "^Income:"', 'year >= 2018'][:count] @@ -108,6 +124,27 @@ def test_build_query_where_arglist_conditions(fy, count, join_op): cond_index = query.index(' WHERE ') + 7 assert query[cond_index:] == '({})'.format(join_op.join(conds)) +@pytest.mark.parametrize('select', [ + ['flag'], + ['check'], + ['flag', 'month'], + ['cost_label', 'cost_metakey'], + ['approval', 'receipt'], +]) +def test_build_query_select_fields(fy, select): + args = query_args(['year>2018'], select=list(select)) + query = qmod.build_query(args, fy) + assert query.startswith('SELECT ') + start_index = 7 + for field in select: + if field != 'flag' and field != 'month' and field != 'cost_label': + field = f'ANY_META("{field}") AS {field.replace("-", "_")}' + match = re.search(rf',\s*{re.escape(field)}\b', query) + assert match, f"field {field!r} not found in query: {query!r}" + assert match.start() >= start_index + start_index = match.end() + assert query[start_index:start_index + 7] == ' WHERE ' + @pytest.mark.parametrize('argname,date_arg', itertools.product( ['start_date', 'stop_date'], [testutil.FY_START_DATE, testutil.FY_START_DATE.year],