query: Add --select to use with WHERE condition arguments.
This commit is contained in:
		
							parent
							
								
									1ca7cccf17
								
							
						
					
					
						commit
						221d42a479
					
				
					 2 changed files with 106 additions and 20 deletions
				
			
		| 
						 | 
					@ -6,17 +6,16 @@
 | 
				
			||||||
# LICENSE.txt in the repository.
 | 
					# LICENSE.txt in the repository.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import argparse
 | 
					import argparse
 | 
				
			||||||
import collections
 | 
					 | 
				
			||||||
import datetime
 | 
					import datetime
 | 
				
			||||||
import enum
 | 
					import enum
 | 
				
			||||||
import functools
 | 
					import itertools
 | 
				
			||||||
import locale
 | 
					 | 
				
			||||||
import logging
 | 
					import logging
 | 
				
			||||||
import re
 | 
					import re
 | 
				
			||||||
import sys
 | 
					import sys
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from typing import (
 | 
					from typing import (
 | 
				
			||||||
    cast,
 | 
					    cast,
 | 
				
			||||||
 | 
					    AbstractSet,
 | 
				
			||||||
    Callable,
 | 
					    Callable,
 | 
				
			||||||
    Dict,
 | 
					    Dict,
 | 
				
			||||||
    Iterable,
 | 
					    Iterable,
 | 
				
			||||||
| 
						 | 
					@ -37,8 +36,9 @@ from ..beancount_types import (
 | 
				
			||||||
from decimal import Decimal
 | 
					from decimal import Decimal
 | 
				
			||||||
from pathlib import Path
 | 
					from pathlib import Path
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import beancount.query.shell as bc_query
 | 
					import beancount.query.query_env as bc_query_env
 | 
				
			||||||
import beancount.query.query_parser as bc_query_parser
 | 
					import beancount.query.query_parser as bc_query_parser
 | 
				
			||||||
 | 
					import beancount.query.shell as bc_query
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from . import core
 | 
					from . import core
 | 
				
			||||||
from . import rewrite
 | 
					from . import rewrite
 | 
				
			||||||
| 
						 | 
					@ -47,6 +47,10 @@ from .. import cliutil
 | 
				
			||||||
from .. import config as configmod
 | 
					from .. import config as configmod
 | 
				
			||||||
from .. import data
 | 
					from .. import data
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					BUILTIN_FIELDS: AbstractSet[str] = frozenset(itertools.chain(
 | 
				
			||||||
 | 
					    bc_query_env.TargetsEnvironment.columns,  # type:ignore[has-type]
 | 
				
			||||||
 | 
					    bc_query_env.TargetsEnvironment.functions,  # type:ignore[has-type]
 | 
				
			||||||
 | 
					))
 | 
				
			||||||
PROGNAME = 'query-report'
 | 
					PROGNAME = 'query-report'
 | 
				
			||||||
QUERY_PARSER = bc_query_parser.Parser()
 | 
					QUERY_PARSER = bc_query_parser.Parser()
 | 
				
			||||||
logger = logging.getLogger('conservancy_beancount.reports.query')
 | 
					logger = logging.getLogger('conservancy_beancount.reports.query')
 | 
				
			||||||
| 
						 | 
					@ -119,24 +123,43 @@ def build_query(
 | 
				
			||||||
) -> Optional[str]:
 | 
					) -> Optional[str]:
 | 
				
			||||||
    if not args.query:
 | 
					    if not args.query:
 | 
				
			||||||
        args.query = [] if in_file is None else [line[:-1] for line in in_file]
 | 
					        args.query = [] if in_file is None else [line[:-1] for line in in_file]
 | 
				
			||||||
        if not any(re.search(r'\S', s) for s in args.query):
 | 
					 | 
				
			||||||
            return None
 | 
					 | 
				
			||||||
    plain_query = ' '.join(args.query)
 | 
					    plain_query = ' '.join(args.query)
 | 
				
			||||||
 | 
					    if not plain_query or plain_query.isspace():
 | 
				
			||||||
 | 
					        return None
 | 
				
			||||||
    try:
 | 
					    try:
 | 
				
			||||||
        QUERY_PARSER.parse(plain_query)
 | 
					        QUERY_PARSER.parse(plain_query)
 | 
				
			||||||
    except bc_query_parser.ParseError:
 | 
					    except bc_query_parser.ParseError:
 | 
				
			||||||
 | 
					        if args.join is None:
 | 
				
			||||||
 | 
					            args.join = JoinOperator.AND
 | 
				
			||||||
 | 
					        select = [
 | 
				
			||||||
 | 
					            'date',
 | 
				
			||||||
 | 
					            'ANY_META("entity") as entity',
 | 
				
			||||||
 | 
					            'narration',
 | 
				
			||||||
 | 
					            'position',
 | 
				
			||||||
 | 
					            'COST(position)',
 | 
				
			||||||
 | 
					            *(f'ANY_META("{field}") AS {field.replace("-", "_")}'
 | 
				
			||||||
 | 
					              if field not in BUILTIN_FIELDS
 | 
				
			||||||
 | 
					              and re.fullmatch(r'[a-z][-_A-Za-z0-9]*', field)
 | 
				
			||||||
 | 
					              else field
 | 
				
			||||||
 | 
					              for field in args.select),
 | 
				
			||||||
 | 
					        ]
 | 
				
			||||||
        conds = [f'({args.join.join(args.query)})']
 | 
					        conds = [f'({args.join.join(args.query)})']
 | 
				
			||||||
        if args.start_date is not None:
 | 
					        if args.start_date is not None:
 | 
				
			||||||
            conds.append(_date_condition(args.start_date, fy.first_date, '>='))
 | 
					            conds.append(_date_condition(args.start_date, fy.first_date, '>='))
 | 
				
			||||||
        if args.stop_date is not None:
 | 
					        if args.stop_date is not None:
 | 
				
			||||||
            conds.append(_date_condition(args.stop_date, fy.next_fy_date, '<'))
 | 
					            conds.append(_date_condition(args.stop_date, fy.next_fy_date, '<'))
 | 
				
			||||||
        return f'SELECT * WHERE {" AND ".join(conds)}'
 | 
					        return f'SELECT {", ".join(select)} WHERE {" AND ".join(conds)}'
 | 
				
			||||||
    else:
 | 
					    else:
 | 
				
			||||||
 | 
					        if args.join:
 | 
				
			||||||
 | 
					            raise ValueError("cannot specify --join with a full query")
 | 
				
			||||||
 | 
					        if args.select:
 | 
				
			||||||
 | 
					            raise ValueError("cannot specify --select with a full query")
 | 
				
			||||||
        return plain_query
 | 
					        return plain_query
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def parse_arguments(arglist: Optional[Sequence[str]]=None) -> argparse.Namespace:
 | 
					def parse_arguments(arglist: Optional[Sequence[str]]=None) -> argparse.Namespace:
 | 
				
			||||||
    parser = argparse.ArgumentParser(prog=PROGNAME)
 | 
					    parser = argparse.ArgumentParser(prog=PROGNAME)
 | 
				
			||||||
    cliutil.add_version_argument(parser)
 | 
					    cliutil.add_version_argument(parser)
 | 
				
			||||||
 | 
					    cliutil.add_loglevel_argument(parser)
 | 
				
			||||||
    parser.add_argument(
 | 
					    parser.add_argument(
 | 
				
			||||||
        '--begin', '--start', '-b',
 | 
					        '--begin', '--start', '-b',
 | 
				
			||||||
        dest='start_date',
 | 
					        dest='start_date',
 | 
				
			||||||
| 
						 | 
					@ -172,25 +195,47 @@ extension, or 'text' if that fails.
 | 
				
			||||||
        help="""Write the report to this file, or stdout when PATH is `-`.
 | 
					        help="""Write the report to this file, or stdout when PATH is `-`.
 | 
				
			||||||
The default is stdout for text and CSV reports, and a generated filename for
 | 
					The default is stdout for text and CSV reports, and a generated filename for
 | 
				
			||||||
ODS reports.
 | 
					ODS reports.
 | 
				
			||||||
 | 
					""")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    query_group = parser.add_argument_group("query options", """
 | 
				
			||||||
 | 
					You can write a single full query as a command line argument (like bean-query),
 | 
				
			||||||
 | 
					or you can write individual WHERE condition(s) as arguments. If you write
 | 
				
			||||||
 | 
					WHERE conditions, these options are used to build the rest of the query.
 | 
				
			||||||
""")
 | 
					""")
 | 
				
			||||||
    join_arg = cliutil.EnumArgument(JoinOperator)
 | 
					    join_arg = cliutil.EnumArgument(JoinOperator)
 | 
				
			||||||
    parser.add_argument(
 | 
					    query_group.add_argument(
 | 
				
			||||||
 | 
					        '--select', '-s',
 | 
				
			||||||
 | 
					        metavar='COLUMN',
 | 
				
			||||||
 | 
					        default=[],
 | 
				
			||||||
 | 
					        action=cliutil.ExtendAction,
 | 
				
			||||||
 | 
					        help="""Columns to select. You can write these as comma-separated
 | 
				
			||||||
 | 
					names, and/or specify the option more than once. You can specify both
 | 
				
			||||||
 | 
					bean-query's built-in column names (like `account` and `flag`) and metadata
 | 
				
			||||||
 | 
					keys.
 | 
				
			||||||
 | 
					""")
 | 
				
			||||||
 | 
					    query_group.add_argument(
 | 
				
			||||||
 | 
					        '--group-by', '-g',
 | 
				
			||||||
 | 
					        metavar='COLUMN',
 | 
				
			||||||
 | 
					        help="""Group output by this column
 | 
				
			||||||
 | 
					""")
 | 
				
			||||||
 | 
					    # query_group.add_argument(
 | 
				
			||||||
 | 
					    #     '--order-by', '--sort', '-r',
 | 
				
			||||||
 | 
					    #     metavar='COLUMN',
 | 
				
			||||||
 | 
					    #     help="""Order output by this column
 | 
				
			||||||
 | 
					    # """),
 | 
				
			||||||
 | 
					    query_group.add_argument(
 | 
				
			||||||
        '--join', '-j',
 | 
					        '--join', '-j',
 | 
				
			||||||
        metavar='OP',
 | 
					        metavar='OP',
 | 
				
			||||||
        type=join_arg.enum_type,
 | 
					        type=join_arg.enum_type,
 | 
				
			||||||
        default=JoinOperator.AND,
 | 
					        help=f"""Join your WHERE conditions with this operator.
 | 
				
			||||||
        help="""When you specify multiple WHERE conditions on the command line
 | 
					 | 
				
			||||||
and let query-report build the query, join conditions with this operator.
 | 
					 | 
				
			||||||
Choices are {join_arg.choices_str()}. Default 'and'.
 | 
					Choices are {join_arg.choices_str()}. Default 'and'.
 | 
				
			||||||
""")
 | 
					"""),
 | 
				
			||||||
    cliutil.add_loglevel_argument(parser)
 | 
					    query_group.add_argument(
 | 
				
			||||||
    parser.add_argument(
 | 
					 | 
				
			||||||
        'query',
 | 
					        'query',
 | 
				
			||||||
        nargs=argparse.ZERO_OR_MORE,
 | 
					        nargs=argparse.ZERO_OR_MORE,
 | 
				
			||||||
        help="""Query to run non-interactively. You can specify a full query
 | 
					        help="""Full query or WHERE conditions to run non-interactively
 | 
				
			||||||
you write yourself, or conditions to follow WHERE and let query-report build
 | 
					 | 
				
			||||||
the rest of the query.
 | 
					 | 
				
			||||||
""")
 | 
					""")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    args = parser.parse_args(arglist)
 | 
					    args = parser.parse_args(arglist)
 | 
				
			||||||
    return args
 | 
					    return args
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -208,7 +253,11 @@ def main(arglist: Optional[Sequence[str]]=None,
 | 
				
			||||||
    fy = config.fiscal_year_begin()
 | 
					    fy = config.fiscal_year_begin()
 | 
				
			||||||
    if args.stop_date is None and args.start_date is not None:
 | 
					    if args.stop_date is None and args.start_date is not None:
 | 
				
			||||||
        args.stop_date = fy.next_fy_date(args.start_date)
 | 
					        args.stop_date = fy.next_fy_date(args.start_date)
 | 
				
			||||||
 | 
					    try:
 | 
				
			||||||
        query = build_query(args, fy, None if sys.stdin.isatty() else sys.stdin)
 | 
					        query = build_query(args, fy, None if sys.stdin.isatty() else sys.stdin)
 | 
				
			||||||
 | 
					    except ValueError as error:
 | 
				
			||||||
 | 
					        logger.error(error.args[0], exc_info=logger.isEnabledFor(logging.DEBUG))
 | 
				
			||||||
 | 
					        return 2
 | 
				
			||||||
    is_interactive = query is None and sys.stdin.isatty()
 | 
					    is_interactive = query is None and sys.stdin.isatty()
 | 
				
			||||||
    if args.report_type is None:
 | 
					    if args.report_type is None:
 | 
				
			||||||
        try:
 | 
					        try:
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -44,8 +44,13 @@ def pipe_main(arglist, config):
 | 
				
			||||||
    returncode = qmod.main(arglist, stdout, stderr, config)
 | 
					    returncode = qmod.main(arglist, stdout, stderr, config)
 | 
				
			||||||
    return returncode, stdout, stderr
 | 
					    return returncode, stdout, stderr
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def query_args(query=None, start_date=None, stop_date=None, join='AND'):
 | 
					def query_args(query=None, start_date=None, stop_date=None, join=None, select=None):
 | 
				
			||||||
 | 
					    if isinstance(join, str):
 | 
				
			||||||
        join = qmod.JoinOperator[join]
 | 
					        join = qmod.JoinOperator[join]
 | 
				
			||||||
 | 
					    if select is None:
 | 
				
			||||||
 | 
					        select = []
 | 
				
			||||||
 | 
					    elif isinstance(select, str):
 | 
				
			||||||
 | 
					        select = select.split(',')
 | 
				
			||||||
    return argparse.Namespace(**locals())
 | 
					    return argparse.Namespace(**locals())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def test_books_loader_empty():
 | 
					def test_books_loader_empty():
 | 
				
			||||||
| 
						 | 
					@ -99,6 +104,17 @@ def test_build_query_in_arglist(fy, query_str):
 | 
				
			||||||
    args = query_args(query_str.split(), testutil.PAST_DATE, testutil.FUTURE_DATE)
 | 
					    args = query_args(query_str.split(), testutil.PAST_DATE, testutil.FUTURE_DATE)
 | 
				
			||||||
    assert qmod.build_query(args, fy) == query_str
 | 
					    assert qmod.build_query(args, fy) == query_str
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.mark.parametrize('argname,argval', [
 | 
				
			||||||
 | 
					    ('join', qmod.JoinOperator.AND),
 | 
				
			||||||
 | 
					    ('join', qmod.JoinOperator.OR),
 | 
				
			||||||
 | 
					    ('select', 'date,flag'),
 | 
				
			||||||
 | 
					    ('select', 'position,rt-id'),
 | 
				
			||||||
 | 
					])
 | 
				
			||||||
 | 
					def test_build_query_cant_mix_switches_with_full_query(fy, argname, argval):
 | 
				
			||||||
 | 
					    args = query_args(['journal'], **{argname: argval})
 | 
				
			||||||
 | 
					    with pytest.raises(ValueError):
 | 
				
			||||||
 | 
					        qmod.build_query(args, fy)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@pytest.mark.parametrize('count,join_op', enumerate(qmod.JoinOperator, 1))
 | 
					@pytest.mark.parametrize('count,join_op', enumerate(qmod.JoinOperator, 1))
 | 
				
			||||||
def test_build_query_where_arglist_conditions(fy, count, join_op):
 | 
					def test_build_query_where_arglist_conditions(fy, count, join_op):
 | 
				
			||||||
    conds = ['account ~ "^Income:"', 'year >= 2018'][:count]
 | 
					    conds = ['account ~ "^Income:"', 'year >= 2018'][:count]
 | 
				
			||||||
| 
						 | 
					@ -108,6 +124,27 @@ def test_build_query_where_arglist_conditions(fy, count, join_op):
 | 
				
			||||||
    cond_index = query.index(' WHERE ') + 7
 | 
					    cond_index = query.index(' WHERE ') + 7
 | 
				
			||||||
    assert query[cond_index:] == '({})'.format(join_op.join(conds))
 | 
					    assert query[cond_index:] == '({})'.format(join_op.join(conds))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.mark.parametrize('select', [
 | 
				
			||||||
 | 
					    ['flag'],
 | 
				
			||||||
 | 
					    ['check'],
 | 
				
			||||||
 | 
					    ['flag', 'month'],
 | 
				
			||||||
 | 
					    ['cost_label', 'cost_metakey'],
 | 
				
			||||||
 | 
					    ['approval', 'receipt'],
 | 
				
			||||||
 | 
					])
 | 
				
			||||||
 | 
					def test_build_query_select_fields(fy, select):
 | 
				
			||||||
 | 
					    args = query_args(['year>2018'], select=list(select))
 | 
				
			||||||
 | 
					    query = qmod.build_query(args, fy)
 | 
				
			||||||
 | 
					    assert query.startswith('SELECT ')
 | 
				
			||||||
 | 
					    start_index = 7
 | 
				
			||||||
 | 
					    for field in select:
 | 
				
			||||||
 | 
					        if field != 'flag' and field != 'month' and field != 'cost_label':
 | 
				
			||||||
 | 
					            field = f'ANY_META("{field}") AS {field.replace("-", "_")}'
 | 
				
			||||||
 | 
					        match = re.search(rf',\s*{re.escape(field)}\b', query)
 | 
				
			||||||
 | 
					        assert match, f"field {field!r} not found in query: {query!r}"
 | 
				
			||||||
 | 
					        assert match.start() >= start_index
 | 
				
			||||||
 | 
					        start_index = match.end()
 | 
				
			||||||
 | 
					    assert query[start_index:start_index + 7] == ' WHERE '
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@pytest.mark.parametrize('argname,date_arg', itertools.product(
 | 
					@pytest.mark.parametrize('argname,date_arg', itertools.product(
 | 
				
			||||||
    ['start_date', 'stop_date'],
 | 
					    ['start_date', 'stop_date'],
 | 
				
			||||||
    [testutil.FY_START_DATE, testutil.FY_START_DATE.year],
 | 
					    [testutil.FY_START_DATE, testutil.FY_START_DATE.year],
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
	Add table
		
		Reference in a new issue