query: Add --select to use with WHERE condition arguments.

This commit is contained in:
Brett Smith 2021-03-05 16:45:24 -05:00
parent 1ca7cccf17
commit 221d42a479
2 changed files with 106 additions and 20 deletions

View file

@ -6,17 +6,16 @@
# LICENSE.txt in the repository. # LICENSE.txt in the repository.
import argparse import argparse
import collections
import datetime import datetime
import enum import enum
import functools import itertools
import locale
import logging import logging
import re import re
import sys import sys
from typing import ( from typing import (
cast, cast,
AbstractSet,
Callable, Callable,
Dict, Dict,
Iterable, Iterable,
@ -37,8 +36,9 @@ from ..beancount_types import (
from decimal import Decimal from decimal import Decimal
from pathlib import Path from pathlib import Path
import beancount.query.shell as bc_query import beancount.query.query_env as bc_query_env
import beancount.query.query_parser as bc_query_parser import beancount.query.query_parser as bc_query_parser
import beancount.query.shell as bc_query
from . import core from . import core
from . import rewrite from . import rewrite
@ -47,6 +47,10 @@ from .. import cliutil
from .. import config as configmod from .. import config as configmod
from .. import data from .. import data
BUILTIN_FIELDS: AbstractSet[str] = frozenset(itertools.chain(
bc_query_env.TargetsEnvironment.columns, # type:ignore[has-type]
bc_query_env.TargetsEnvironment.functions, # type:ignore[has-type]
))
PROGNAME = 'query-report' PROGNAME = 'query-report'
QUERY_PARSER = bc_query_parser.Parser() QUERY_PARSER = bc_query_parser.Parser()
logger = logging.getLogger('conservancy_beancount.reports.query') logger = logging.getLogger('conservancy_beancount.reports.query')
@ -119,24 +123,43 @@ def build_query(
) -> Optional[str]: ) -> Optional[str]:
if not args.query: if not args.query:
args.query = [] if in_file is None else [line[:-1] for line in in_file] args.query = [] if in_file is None else [line[:-1] for line in in_file]
if not any(re.search(r'\S', s) for s in args.query):
return None
plain_query = ' '.join(args.query) plain_query = ' '.join(args.query)
if not plain_query or plain_query.isspace():
return None
try: try:
QUERY_PARSER.parse(plain_query) QUERY_PARSER.parse(plain_query)
except bc_query_parser.ParseError: except bc_query_parser.ParseError:
if args.join is None:
args.join = JoinOperator.AND
select = [
'date',
'ANY_META("entity") as entity',
'narration',
'position',
'COST(position)',
*(f'ANY_META("{field}") AS {field.replace("-", "_")}'
if field not in BUILTIN_FIELDS
and re.fullmatch(r'[a-z][-_A-Za-z0-9]*', field)
else field
for field in args.select),
]
conds = [f'({args.join.join(args.query)})'] conds = [f'({args.join.join(args.query)})']
if args.start_date is not None: if args.start_date is not None:
conds.append(_date_condition(args.start_date, fy.first_date, '>=')) conds.append(_date_condition(args.start_date, fy.first_date, '>='))
if args.stop_date is not None: if args.stop_date is not None:
conds.append(_date_condition(args.stop_date, fy.next_fy_date, '<')) conds.append(_date_condition(args.stop_date, fy.next_fy_date, '<'))
return f'SELECT * WHERE {" AND ".join(conds)}' return f'SELECT {", ".join(select)} WHERE {" AND ".join(conds)}'
else: else:
if args.join:
raise ValueError("cannot specify --join with a full query")
if args.select:
raise ValueError("cannot specify --select with a full query")
return plain_query return plain_query
def parse_arguments(arglist: Optional[Sequence[str]]=None) -> argparse.Namespace: def parse_arguments(arglist: Optional[Sequence[str]]=None) -> argparse.Namespace:
parser = argparse.ArgumentParser(prog=PROGNAME) parser = argparse.ArgumentParser(prog=PROGNAME)
cliutil.add_version_argument(parser) cliutil.add_version_argument(parser)
cliutil.add_loglevel_argument(parser)
parser.add_argument( parser.add_argument(
'--begin', '--start', '-b', '--begin', '--start', '-b',
dest='start_date', dest='start_date',
@ -172,25 +195,47 @@ extension, or 'text' if that fails.
help="""Write the report to this file, or stdout when PATH is `-`. help="""Write the report to this file, or stdout when PATH is `-`.
The default is stdout for text and CSV reports, and a generated filename for The default is stdout for text and CSV reports, and a generated filename for
ODS reports. ODS reports.
""")
query_group = parser.add_argument_group("query options", """
You can write a single full query as a command line argument (like bean-query),
or you can write individual WHERE condition(s) as arguments. If you write
WHERE conditions, these options are used to build the rest of the query.
""") """)
join_arg = cliutil.EnumArgument(JoinOperator) join_arg = cliutil.EnumArgument(JoinOperator)
parser.add_argument( query_group.add_argument(
'--select', '-s',
metavar='COLUMN',
default=[],
action=cliutil.ExtendAction,
help="""Columns to select. You can write these as comma-separated
names, and/or specify the option more than once. You can specify both
bean-query's built-in column names (like `account` and `flag`) and metadata
keys.
""")
query_group.add_argument(
'--group-by', '-g',
metavar='COLUMN',
help="""Group output by this column
""")
# query_group.add_argument(
# '--order-by', '--sort', '-r',
# metavar='COLUMN',
# help="""Order output by this column
# """),
query_group.add_argument(
'--join', '-j', '--join', '-j',
metavar='OP', metavar='OP',
type=join_arg.enum_type, type=join_arg.enum_type,
default=JoinOperator.AND, help=f"""Join your WHERE conditions with this operator.
help="""When you specify multiple WHERE conditions on the command line
and let query-report build the query, join conditions with this operator.
Choices are {join_arg.choices_str()}. Default 'and'. Choices are {join_arg.choices_str()}. Default 'and'.
""") """),
cliutil.add_loglevel_argument(parser) query_group.add_argument(
parser.add_argument(
'query', 'query',
nargs=argparse.ZERO_OR_MORE, nargs=argparse.ZERO_OR_MORE,
help="""Query to run non-interactively. You can specify a full query help="""Full query or WHERE conditions to run non-interactively
you write yourself, or conditions to follow WHERE and let query-report build
the rest of the query.
""") """)
args = parser.parse_args(arglist) args = parser.parse_args(arglist)
return args return args
@ -208,7 +253,11 @@ def main(arglist: Optional[Sequence[str]]=None,
fy = config.fiscal_year_begin() fy = config.fiscal_year_begin()
if args.stop_date is None and args.start_date is not None: if args.stop_date is None and args.start_date is not None:
args.stop_date = fy.next_fy_date(args.start_date) args.stop_date = fy.next_fy_date(args.start_date)
query = build_query(args, fy, None if sys.stdin.isatty() else sys.stdin) try:
query = build_query(args, fy, None if sys.stdin.isatty() else sys.stdin)
except ValueError as error:
logger.error(error.args[0], exc_info=logger.isEnabledFor(logging.DEBUG))
return 2
is_interactive = query is None and sys.stdin.isatty() is_interactive = query is None and sys.stdin.isatty()
if args.report_type is None: if args.report_type is None:
try: try:

View file

@ -44,8 +44,13 @@ def pipe_main(arglist, config):
returncode = qmod.main(arglist, stdout, stderr, config) returncode = qmod.main(arglist, stdout, stderr, config)
return returncode, stdout, stderr return returncode, stdout, stderr
def query_args(query=None, start_date=None, stop_date=None, join='AND'): def query_args(query=None, start_date=None, stop_date=None, join=None, select=None):
join = qmod.JoinOperator[join] if isinstance(join, str):
join = qmod.JoinOperator[join]
if select is None:
select = []
elif isinstance(select, str):
select = select.split(',')
return argparse.Namespace(**locals()) return argparse.Namespace(**locals())
def test_books_loader_empty(): def test_books_loader_empty():
@ -99,6 +104,17 @@ def test_build_query_in_arglist(fy, query_str):
args = query_args(query_str.split(), testutil.PAST_DATE, testutil.FUTURE_DATE) args = query_args(query_str.split(), testutil.PAST_DATE, testutil.FUTURE_DATE)
assert qmod.build_query(args, fy) == query_str assert qmod.build_query(args, fy) == query_str
@pytest.mark.parametrize('argname,argval', [
('join', qmod.JoinOperator.AND),
('join', qmod.JoinOperator.OR),
('select', 'date,flag'),
('select', 'position,rt-id'),
])
def test_build_query_cant_mix_switches_with_full_query(fy, argname, argval):
args = query_args(['journal'], **{argname: argval})
with pytest.raises(ValueError):
qmod.build_query(args, fy)
@pytest.mark.parametrize('count,join_op', enumerate(qmod.JoinOperator, 1)) @pytest.mark.parametrize('count,join_op', enumerate(qmod.JoinOperator, 1))
def test_build_query_where_arglist_conditions(fy, count, join_op): def test_build_query_where_arglist_conditions(fy, count, join_op):
conds = ['account ~ "^Income:"', 'year >= 2018'][:count] conds = ['account ~ "^Income:"', 'year >= 2018'][:count]
@ -108,6 +124,27 @@ def test_build_query_where_arglist_conditions(fy, count, join_op):
cond_index = query.index(' WHERE ') + 7 cond_index = query.index(' WHERE ') + 7
assert query[cond_index:] == '({})'.format(join_op.join(conds)) assert query[cond_index:] == '({})'.format(join_op.join(conds))
@pytest.mark.parametrize('select', [
['flag'],
['check'],
['flag', 'month'],
['cost_label', 'cost_metakey'],
['approval', 'receipt'],
])
def test_build_query_select_fields(fy, select):
args = query_args(['year>2018'], select=list(select))
query = qmod.build_query(args, fy)
assert query.startswith('SELECT ')
start_index = 7
for field in select:
if field != 'flag' and field != 'month' and field != 'cost_label':
field = f'ANY_META("{field}") AS {field.replace("-", "_")}'
match = re.search(rf',\s*{re.escape(field)}\b', query)
assert match, f"field {field!r} not found in query: {query!r}"
assert match.start() >= start_index
start_index = match.end()
assert query[start_index:start_index + 7] == ' WHERE '
@pytest.mark.parametrize('argname,date_arg', itertools.product( @pytest.mark.parametrize('argname,date_arg', itertools.product(
['start_date', 'stop_date'], ['start_date', 'stop_date'],
[testutil.FY_START_DATE, testutil.FY_START_DATE.year], [testutil.FY_START_DATE, testutil.FY_START_DATE.year],