query: Add --select to use with WHERE condition arguments.

This commit is contained in:
Brett Smith 2021-03-05 16:45:24 -05:00
parent 1ca7cccf17
commit 221d42a479
2 changed files with 106 additions and 20 deletions

View file

@ -6,17 +6,16 @@
# LICENSE.txt in the repository.
import argparse
import collections
import datetime
import enum
import functools
import locale
import itertools
import logging
import re
import sys
from typing import (
cast,
AbstractSet,
Callable,
Dict,
Iterable,
@ -37,8 +36,9 @@ from ..beancount_types import (
from decimal import Decimal
from pathlib import Path
import beancount.query.shell as bc_query
import beancount.query.query_env as bc_query_env
import beancount.query.query_parser as bc_query_parser
import beancount.query.shell as bc_query
from . import core
from . import rewrite
@ -47,6 +47,10 @@ from .. import cliutil
from .. import config as configmod
from .. import data
BUILTIN_FIELDS: AbstractSet[str] = frozenset(itertools.chain(
bc_query_env.TargetsEnvironment.columns, # type:ignore[has-type]
bc_query_env.TargetsEnvironment.functions, # type:ignore[has-type]
))
PROGNAME = 'query-report'
QUERY_PARSER = bc_query_parser.Parser()
logger = logging.getLogger('conservancy_beancount.reports.query')
@ -119,24 +123,43 @@ def build_query(
) -> Optional[str]:
if not args.query:
args.query = [] if in_file is None else [line[:-1] for line in in_file]
if not any(re.search(r'\S', s) for s in args.query):
return None
plain_query = ' '.join(args.query)
if not plain_query or plain_query.isspace():
return None
try:
QUERY_PARSER.parse(plain_query)
except bc_query_parser.ParseError:
if args.join is None:
args.join = JoinOperator.AND
select = [
'date',
'ANY_META("entity") as entity',
'narration',
'position',
'COST(position)',
*(f'ANY_META("{field}") AS {field.replace("-", "_")}'
if field not in BUILTIN_FIELDS
and re.fullmatch(r'[a-z][-_A-Za-z0-9]*', field)
else field
for field in args.select),
]
conds = [f'({args.join.join(args.query)})']
if args.start_date is not None:
conds.append(_date_condition(args.start_date, fy.first_date, '>='))
if args.stop_date is not None:
conds.append(_date_condition(args.stop_date, fy.next_fy_date, '<'))
return f'SELECT * WHERE {" AND ".join(conds)}'
return f'SELECT {", ".join(select)} WHERE {" AND ".join(conds)}'
else:
if args.join:
raise ValueError("cannot specify --join with a full query")
if args.select:
raise ValueError("cannot specify --select with a full query")
return plain_query
def parse_arguments(arglist: Optional[Sequence[str]]=None) -> argparse.Namespace:
parser = argparse.ArgumentParser(prog=PROGNAME)
cliutil.add_version_argument(parser)
cliutil.add_loglevel_argument(parser)
parser.add_argument(
'--begin', '--start', '-b',
dest='start_date',
@ -172,25 +195,47 @@ extension, or 'text' if that fails.
help="""Write the report to this file, or stdout when PATH is `-`.
The default is stdout for text and CSV reports, and a generated filename for
ODS reports.
""")
query_group = parser.add_argument_group("query options", """
You can write a single full query as a command line argument (like bean-query),
or you can write individual WHERE condition(s) as arguments. If you write
WHERE conditions, these options are used to build the rest of the query.
""")
join_arg = cliutil.EnumArgument(JoinOperator)
parser.add_argument(
query_group.add_argument(
'--select', '-s',
metavar='COLUMN',
default=[],
action=cliutil.ExtendAction,
help="""Columns to select. You can write these as comma-separated
names, and/or specify the option more than once. You can specify both
bean-query's built-in column names (like `account` and `flag`) and metadata
keys.
""")
query_group.add_argument(
'--group-by', '-g',
metavar='COLUMN',
help="""Group output by this column
""")
# query_group.add_argument(
# '--order-by', '--sort', '-r',
# metavar='COLUMN',
# help="""Order output by this column
# """),
query_group.add_argument(
'--join', '-j',
metavar='OP',
type=join_arg.enum_type,
default=JoinOperator.AND,
help="""When you specify multiple WHERE conditions on the command line
and let query-report build the query, join conditions with this operator.
help=f"""Join your WHERE conditions with this operator.
Choices are {join_arg.choices_str()}. Default 'and'.
""")
cliutil.add_loglevel_argument(parser)
parser.add_argument(
"""),
query_group.add_argument(
'query',
nargs=argparse.ZERO_OR_MORE,
help="""Query to run non-interactively. You can specify a full query
you write yourself, or conditions to follow WHERE and let query-report build
the rest of the query.
help="""Full query or WHERE conditions to run non-interactively
""")
args = parser.parse_args(arglist)
return args
@ -208,7 +253,11 @@ def main(arglist: Optional[Sequence[str]]=None,
fy = config.fiscal_year_begin()
if args.stop_date is None and args.start_date is not None:
args.stop_date = fy.next_fy_date(args.start_date)
query = build_query(args, fy, None if sys.stdin.isatty() else sys.stdin)
try:
query = build_query(args, fy, None if sys.stdin.isatty() else sys.stdin)
except ValueError as error:
logger.error(error.args[0], exc_info=logger.isEnabledFor(logging.DEBUG))
return 2
is_interactive = query is None and sys.stdin.isatty()
if args.report_type is None:
try:

View file

@ -44,8 +44,13 @@ def pipe_main(arglist, config):
returncode = qmod.main(arglist, stdout, stderr, config)
return returncode, stdout, stderr
def query_args(query=None, start_date=None, stop_date=None, join='AND'):
join = qmod.JoinOperator[join]
def query_args(query=None, start_date=None, stop_date=None, join=None, select=None):
if isinstance(join, str):
join = qmod.JoinOperator[join]
if select is None:
select = []
elif isinstance(select, str):
select = select.split(',')
return argparse.Namespace(**locals())
def test_books_loader_empty():
@ -99,6 +104,17 @@ def test_build_query_in_arglist(fy, query_str):
args = query_args(query_str.split(), testutil.PAST_DATE, testutil.FUTURE_DATE)
assert qmod.build_query(args, fy) == query_str
@pytest.mark.parametrize('argname,argval', [
('join', qmod.JoinOperator.AND),
('join', qmod.JoinOperator.OR),
('select', 'date,flag'),
('select', 'position,rt-id'),
])
def test_build_query_cant_mix_switches_with_full_query(fy, argname, argval):
args = query_args(['journal'], **{argname: argval})
with pytest.raises(ValueError):
qmod.build_query(args, fy)
@pytest.mark.parametrize('count,join_op', enumerate(qmod.JoinOperator, 1))
def test_build_query_where_arglist_conditions(fy, count, join_op):
conds = ['account ~ "^Income:"', 'year >= 2018'][:count]
@ -108,6 +124,27 @@ def test_build_query_where_arglist_conditions(fy, count, join_op):
cond_index = query.index(' WHERE ') + 7
assert query[cond_index:] == '({})'.format(join_op.join(conds))
@pytest.mark.parametrize('select', [
['flag'],
['check'],
['flag', 'month'],
['cost_label', 'cost_metakey'],
['approval', 'receipt'],
])
def test_build_query_select_fields(fy, select):
args = query_args(['year>2018'], select=list(select))
query = qmod.build_query(args, fy)
assert query.startswith('SELECT ')
start_index = 7
for field in select:
if field != 'flag' and field != 'month' and field != 'cost_label':
field = f'ANY_META("{field}") AS {field.replace("-", "_")}'
match = re.search(rf',\s*{re.escape(field)}\b', query)
assert match, f"field {field!r} not found in query: {query!r}"
assert match.start() >= start_index
start_index = match.end()
assert query[start_index:start_index + 7] == ' WHERE '
@pytest.mark.parametrize('argname,date_arg', itertools.product(
['start_date', 'stop_date'],
[testutil.FY_START_DATE, testutil.FY_START_DATE.year],