query: Add --select to use with WHERE condition arguments.
This commit is contained in:
parent
1ca7cccf17
commit
221d42a479
2 changed files with 106 additions and 20 deletions
|
@ -6,17 +6,16 @@
|
|||
# LICENSE.txt in the repository.
|
||||
|
||||
import argparse
|
||||
import collections
|
||||
import datetime
|
||||
import enum
|
||||
import functools
|
||||
import locale
|
||||
import itertools
|
||||
import logging
|
||||
import re
|
||||
import sys
|
||||
|
||||
from typing import (
|
||||
cast,
|
||||
AbstractSet,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterable,
|
||||
|
@ -37,8 +36,9 @@ from ..beancount_types import (
|
|||
from decimal import Decimal
|
||||
from pathlib import Path
|
||||
|
||||
import beancount.query.shell as bc_query
|
||||
import beancount.query.query_env as bc_query_env
|
||||
import beancount.query.query_parser as bc_query_parser
|
||||
import beancount.query.shell as bc_query
|
||||
|
||||
from . import core
|
||||
from . import rewrite
|
||||
|
@ -47,6 +47,10 @@ from .. import cliutil
|
|||
from .. import config as configmod
|
||||
from .. import data
|
||||
|
||||
BUILTIN_FIELDS: AbstractSet[str] = frozenset(itertools.chain(
|
||||
bc_query_env.TargetsEnvironment.columns, # type:ignore[has-type]
|
||||
bc_query_env.TargetsEnvironment.functions, # type:ignore[has-type]
|
||||
))
|
||||
PROGNAME = 'query-report'
|
||||
QUERY_PARSER = bc_query_parser.Parser()
|
||||
logger = logging.getLogger('conservancy_beancount.reports.query')
|
||||
|
@ -119,24 +123,43 @@ def build_query(
|
|||
) -> Optional[str]:
|
||||
if not args.query:
|
||||
args.query = [] if in_file is None else [line[:-1] for line in in_file]
|
||||
if not any(re.search(r'\S', s) for s in args.query):
|
||||
return None
|
||||
plain_query = ' '.join(args.query)
|
||||
if not plain_query or plain_query.isspace():
|
||||
return None
|
||||
try:
|
||||
QUERY_PARSER.parse(plain_query)
|
||||
except bc_query_parser.ParseError:
|
||||
if args.join is None:
|
||||
args.join = JoinOperator.AND
|
||||
select = [
|
||||
'date',
|
||||
'ANY_META("entity") as entity',
|
||||
'narration',
|
||||
'position',
|
||||
'COST(position)',
|
||||
*(f'ANY_META("{field}") AS {field.replace("-", "_")}'
|
||||
if field not in BUILTIN_FIELDS
|
||||
and re.fullmatch(r'[a-z][-_A-Za-z0-9]*', field)
|
||||
else field
|
||||
for field in args.select),
|
||||
]
|
||||
conds = [f'({args.join.join(args.query)})']
|
||||
if args.start_date is not None:
|
||||
conds.append(_date_condition(args.start_date, fy.first_date, '>='))
|
||||
if args.stop_date is not None:
|
||||
conds.append(_date_condition(args.stop_date, fy.next_fy_date, '<'))
|
||||
return f'SELECT * WHERE {" AND ".join(conds)}'
|
||||
return f'SELECT {", ".join(select)} WHERE {" AND ".join(conds)}'
|
||||
else:
|
||||
if args.join:
|
||||
raise ValueError("cannot specify --join with a full query")
|
||||
if args.select:
|
||||
raise ValueError("cannot specify --select with a full query")
|
||||
return plain_query
|
||||
|
||||
def parse_arguments(arglist: Optional[Sequence[str]]=None) -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(prog=PROGNAME)
|
||||
cliutil.add_version_argument(parser)
|
||||
cliutil.add_loglevel_argument(parser)
|
||||
parser.add_argument(
|
||||
'--begin', '--start', '-b',
|
||||
dest='start_date',
|
||||
|
@ -172,25 +195,47 @@ extension, or 'text' if that fails.
|
|||
help="""Write the report to this file, or stdout when PATH is `-`.
|
||||
The default is stdout for text and CSV reports, and a generated filename for
|
||||
ODS reports.
|
||||
""")
|
||||
|
||||
query_group = parser.add_argument_group("query options", """
|
||||
You can write a single full query as a command line argument (like bean-query),
|
||||
or you can write individual WHERE condition(s) as arguments. If you write
|
||||
WHERE conditions, these options are used to build the rest of the query.
|
||||
""")
|
||||
join_arg = cliutil.EnumArgument(JoinOperator)
|
||||
parser.add_argument(
|
||||
query_group.add_argument(
|
||||
'--select', '-s',
|
||||
metavar='COLUMN',
|
||||
default=[],
|
||||
action=cliutil.ExtendAction,
|
||||
help="""Columns to select. You can write these as comma-separated
|
||||
names, and/or specify the option more than once. You can specify both
|
||||
bean-query's built-in column names (like `account` and `flag`) and metadata
|
||||
keys.
|
||||
""")
|
||||
query_group.add_argument(
|
||||
'--group-by', '-g',
|
||||
metavar='COLUMN',
|
||||
help="""Group output by this column
|
||||
""")
|
||||
# query_group.add_argument(
|
||||
# '--order-by', '--sort', '-r',
|
||||
# metavar='COLUMN',
|
||||
# help="""Order output by this column
|
||||
# """),
|
||||
query_group.add_argument(
|
||||
'--join', '-j',
|
||||
metavar='OP',
|
||||
type=join_arg.enum_type,
|
||||
default=JoinOperator.AND,
|
||||
help="""When you specify multiple WHERE conditions on the command line
|
||||
and let query-report build the query, join conditions with this operator.
|
||||
help=f"""Join your WHERE conditions with this operator.
|
||||
Choices are {join_arg.choices_str()}. Default 'and'.
|
||||
""")
|
||||
cliutil.add_loglevel_argument(parser)
|
||||
parser.add_argument(
|
||||
"""),
|
||||
query_group.add_argument(
|
||||
'query',
|
||||
nargs=argparse.ZERO_OR_MORE,
|
||||
help="""Query to run non-interactively. You can specify a full query
|
||||
you write yourself, or conditions to follow WHERE and let query-report build
|
||||
the rest of the query.
|
||||
help="""Full query or WHERE conditions to run non-interactively
|
||||
""")
|
||||
|
||||
args = parser.parse_args(arglist)
|
||||
return args
|
||||
|
||||
|
@ -208,7 +253,11 @@ def main(arglist: Optional[Sequence[str]]=None,
|
|||
fy = config.fiscal_year_begin()
|
||||
if args.stop_date is None and args.start_date is not None:
|
||||
args.stop_date = fy.next_fy_date(args.start_date)
|
||||
query = build_query(args, fy, None if sys.stdin.isatty() else sys.stdin)
|
||||
try:
|
||||
query = build_query(args, fy, None if sys.stdin.isatty() else sys.stdin)
|
||||
except ValueError as error:
|
||||
logger.error(error.args[0], exc_info=logger.isEnabledFor(logging.DEBUG))
|
||||
return 2
|
||||
is_interactive = query is None and sys.stdin.isatty()
|
||||
if args.report_type is None:
|
||||
try:
|
||||
|
|
|
@ -44,8 +44,13 @@ def pipe_main(arglist, config):
|
|||
returncode = qmod.main(arglist, stdout, stderr, config)
|
||||
return returncode, stdout, stderr
|
||||
|
||||
def query_args(query=None, start_date=None, stop_date=None, join='AND'):
|
||||
join = qmod.JoinOperator[join]
|
||||
def query_args(query=None, start_date=None, stop_date=None, join=None, select=None):
|
||||
if isinstance(join, str):
|
||||
join = qmod.JoinOperator[join]
|
||||
if select is None:
|
||||
select = []
|
||||
elif isinstance(select, str):
|
||||
select = select.split(',')
|
||||
return argparse.Namespace(**locals())
|
||||
|
||||
def test_books_loader_empty():
|
||||
|
@ -99,6 +104,17 @@ def test_build_query_in_arglist(fy, query_str):
|
|||
args = query_args(query_str.split(), testutil.PAST_DATE, testutil.FUTURE_DATE)
|
||||
assert qmod.build_query(args, fy) == query_str
|
||||
|
||||
@pytest.mark.parametrize('argname,argval', [
|
||||
('join', qmod.JoinOperator.AND),
|
||||
('join', qmod.JoinOperator.OR),
|
||||
('select', 'date,flag'),
|
||||
('select', 'position,rt-id'),
|
||||
])
|
||||
def test_build_query_cant_mix_switches_with_full_query(fy, argname, argval):
|
||||
args = query_args(['journal'], **{argname: argval})
|
||||
with pytest.raises(ValueError):
|
||||
qmod.build_query(args, fy)
|
||||
|
||||
@pytest.mark.parametrize('count,join_op', enumerate(qmod.JoinOperator, 1))
|
||||
def test_build_query_where_arglist_conditions(fy, count, join_op):
|
||||
conds = ['account ~ "^Income:"', 'year >= 2018'][:count]
|
||||
|
@ -108,6 +124,27 @@ def test_build_query_where_arglist_conditions(fy, count, join_op):
|
|||
cond_index = query.index(' WHERE ') + 7
|
||||
assert query[cond_index:] == '({})'.format(join_op.join(conds))
|
||||
|
||||
@pytest.mark.parametrize('select', [
|
||||
['flag'],
|
||||
['check'],
|
||||
['flag', 'month'],
|
||||
['cost_label', 'cost_metakey'],
|
||||
['approval', 'receipt'],
|
||||
])
|
||||
def test_build_query_select_fields(fy, select):
|
||||
args = query_args(['year>2018'], select=list(select))
|
||||
query = qmod.build_query(args, fy)
|
||||
assert query.startswith('SELECT ')
|
||||
start_index = 7
|
||||
for field in select:
|
||||
if field != 'flag' and field != 'month' and field != 'cost_label':
|
||||
field = f'ANY_META("{field}") AS {field.replace("-", "_")}'
|
||||
match = re.search(rf',\s*{re.escape(field)}\b', query)
|
||||
assert match, f"field {field!r} not found in query: {query!r}"
|
||||
assert match.start() >= start_index
|
||||
start_index = match.end()
|
||||
assert query[start_index:start_index + 7] == ' WHERE '
|
||||
|
||||
@pytest.mark.parametrize('argname,date_arg', itertools.product(
|
||||
['start_date', 'stop_date'],
|
||||
[testutil.FY_START_DATE, testutil.FY_START_DATE.year],
|
||||
|
|
Loading…
Reference in a new issue