diff --git a/conservancy_beancount/reports/query.py b/conservancy_beancount/reports/query.py index d6b0062..77a5222 100644 --- a/conservancy_beancount/reports/query.py +++ b/conservancy_beancount/reports/query.py @@ -22,6 +22,7 @@ from typing import ( Dict, Iterable, Iterator, + List, Mapping, NamedTuple, Optional, @@ -63,12 +64,27 @@ BUILTIN_FIELDS: AbstractSet[str] = frozenset(itertools.chain( bc_query_env.TargetsEnvironment.functions, # type:ignore[has-type] )) PROGNAME = 'query-report' -QUERY_PARSER = bc_query_parser.Parser() logger = logging.getLogger('conservancy_beancount.reports.query') CellFunc = Callable[[Any], odf.table.TableCell] +EnvironmentFunctions = Dict[ + # The real key type is something like: + # Union[str, Tuple[str, Type, ...]] + # but two issues with that. One, you can't use Ellipses in a Tuple like + # that, so there's no short way to declare this. Second, Beancount doesn't + # declare it anyway, and mypy infers it as Sequence[object]. So just use + # that. + Sequence[object], + Type[bc_query_compile.EvalFunction], +] RowTypes = Sequence[Tuple[str, Type]] Rows = Sequence[NamedTuple] +Store = List[Any] +QueryStatement = Union[ + bc_query_parser.Balances, + bc_query_parser.Journal, + bc_query_parser.Select, +] class BooksLoader: """Closure to load books with a zero-argument callable @@ -176,6 +192,73 @@ class QueryODS(core.BaseODS[NamedTuple, None]): )) +# This class mostly supports type checking. Beancount code dynamically sets the +# ``store`` attribute, in bc_query_execute.execute_query(). +class Context(bc_query_execute.RowContext): + store: Store + + +class MetaDocs(bc_query_env.AnyMeta): + """Return a list of document links from metadata.""" + def __init__(self, operands: List[str]) -> None: + super(bc_query_env.AnyMeta, self).__init__(operands, list) + + def __call__(self, context: Context) -> List[str]: + raw_value = super().__call__(context) + if isinstance(raw_value, str): + return raw_value.split() + else: + return [] + + +class StrMeta(bc_query_env.AnyMeta): + """Looks up metadata like AnyMeta, then always returns a string.""" + def __init__(self, operands: List[str]) -> None: + super(bc_query_env.AnyMeta, self).__init__(operands, str) + + def __call__(self, context: Context) -> str: + raw_value = super().__call__(context) + if raw_value is None: + return '' + else: + return str(raw_value) + + +class AggregateSet(bc_query_compile.EvalAggregator): + __intypes__ = [object] + + def __init__(self, operands: List[str]) -> None: + super().__init__(operands, set) + + def allocate(self, allocator: bc_query_execute.Allocator) -> None: + self.handle = allocator.allocate() + + def initialize(self, store: Store) -> None: + store[self.handle] = self.dtype() + + def update(self, store: Store, context: Context) -> None: + value, = self.eval_args(context) + if isinstance(value, Sequence) and not isinstance(value, str): + store[self.handle].update(value) + else: + store[self.handle].add(value) + + def __call__(self, context: Context) -> set: + return context.store[self.handle] # type:ignore[no-any-return] + + +class FilterPostingsEnvironment(bc_query_env.FilterPostingsEnvironment): + functions: EnvironmentFunctions = bc_query_env.FilterPostingsEnvironment.functions.copy() # type:ignore[assignment] + functions['meta_docs'] = MetaDocs + functions['str_meta'] = StrMeta + + +class TargetsEnvironment(bc_query_env.TargetsEnvironment): + functions = FilterPostingsEnvironment.functions.copy() + functions.update(bc_query_env.AGGREGATOR_FUNCTIONS) + functions['set'] = AggregateSet + + class BQLShell(bc_query_shell.BQLShell): def __init__( self, @@ -187,9 +270,11 @@ class BQLShell(bc_query_shell.BQLShell): rt_wrapper: Optional[rtutil.RT]=None, ) -> None: super().__init__(is_interactive, loadfun, outfile, default_format, do_numberify) + self.env_postings = FilterPostingsEnvironment() + self.env_targets = TargetsEnvironment() self.ods = QueryODS(rt_wrapper) - def on_Select(self, statement: str) -> None: + def on_Select(self, statement: QueryStatement) -> None: output_format: str = self.vars['format'] try: render_func = getattr(self, f'_render_{output_format}') @@ -206,7 +291,7 @@ class BQLShell(bc_query_shell.BQLShell): row_types, rows = bc_query_execute.execute_query( compiled_query, self.entries, self.options_map, ) - if self.vars['numberify'] and output_format != 'ods': + if self.vars['numberify']: logger.debug("numberifying query") row_types, rows = bc_query_numberify.numberify_results( row_types, rows, self.options_map['dcontext'].build(), @@ -251,14 +336,6 @@ class BQLShell(bc_query_shell.BQLShell): ) -class JoinOperator(enum.Enum): - AND = 'AND' - OR = 'OR' - - def join(self, parts: Iterable[str]) -> str: - return f' {self.value} '.join(parts) - - class ReportFormat(enum.Enum): TEXT = 'text' TXT = TEXT @@ -266,54 +343,6 @@ class ReportFormat(enum.Enum): ODS = 'ods' -def _date_condition( - date: Union[int, datetime.date], - year_to_date: Callable[[int], datetime.date], - op: str, -) -> str: - if isinstance(date, int): - date = year_to_date(date) - return f'date {op} {date.isoformat()}' - -def build_query( - args: argparse.Namespace, - fy: books.FiscalYear, - in_file: Optional[TextIO]=None, -) -> Optional[str]: - if not args.query: - args.query = [] if in_file is None else [line[:-1] for line in in_file] - plain_query = ' '.join(args.query) - if not plain_query or plain_query.isspace(): - return None - try: - QUERY_PARSER.parse(plain_query) - except bc_query_parser.ParseError: - if args.join is None: - args.join = JoinOperator.AND - select = [ - 'date', - 'ANY_META("entity") AS entity', - 'narration AS description', - 'COST(position) AS booked_amount', - *(f'ANY_META("{field}") AS {field.replace("-", "_")}' - if field not in BUILTIN_FIELDS - and re.fullmatch(r'[a-z][-_A-Za-z0-9]*', field) - else field - for field in args.select), - ] - conds = [f'({args.join.join(args.query)})'] - if args.start_date is not None: - conds.append(_date_condition(args.start_date, fy.first_date, '>=')) - if args.stop_date is not None: - conds.append(_date_condition(args.stop_date, fy.next_fy_date, '<')) - return f'SELECT {", ".join(select)} WHERE {" AND ".join(conds)}' - else: - if args.join: - raise ValueError("cannot specify --join with a full query") - if args.select: - raise ValueError("cannot specify --select with a full query") - return plain_query - def parse_arguments(arglist: Optional[Sequence[str]]=None) -> argparse.Namespace: parser = argparse.ArgumentParser(prog=PROGNAME) cliutil.add_version_argument(parser) @@ -321,20 +350,18 @@ def parse_arguments(arglist: Optional[Sequence[str]]=None) -> argparse.Namespace parser.add_argument( '--begin', '--start', '-b', dest='start_date', - metavar='DATE', + metavar='YEAR', type=cliutil.year_or_date_arg, - help="""Begin loading entries from this fiscal year. When query-report -builds the query, it will include a condition `date >= DATE`. + help="""Begin loading entries from this fiscal year. You can specify a +full date, and %(prog)s will use the fiscal year for that date. """) parser.add_argument( '--end', '--stop', '-e', dest='stop_date', - metavar='DATE', + metavar='YEAR', type=cliutil.year_or_date_arg, - help="""End loading entries from this fiscal year. When query-report -builds the query, it will include a condition `date < DATE`. If you specify a -begin date but not an end date, the default end date will be the end of the -fiscal year of the begin date. + help="""End loading entries at this fiscal year. You can specify a +full date, and %(prog)s will use the fiscal year for that date. """) cliutil.add_rewrite_rules_argument(parser) format_arg = cliutil.EnumArgument(ReportFormat) @@ -344,7 +371,8 @@ fiscal year of the begin date. type=format_arg.enum_type, help="""Format of report to generate. Choices are {format_arg.choices_str()}. Default is guessed from your output filename -extension, or 'text' if that fails. +extension. If that fails, default is 'text' for interactive output, and 'ods' +otherwise. """) parser.add_argument( '--numberify', '-m', @@ -359,44 +387,12 @@ extension, or 'text' if that fails. The default is stdout for text and CSV reports, and a generated filename for ODS reports. """) - - query_group = parser.add_argument_group("query options", """ -You can write a single full query as a command line argument (like bean-query), -or you can write individual WHERE condition(s) as arguments. If you write -WHERE conditions, these options are used to build the rest of the query. -""") - join_arg = cliutil.EnumArgument(JoinOperator) - query_group.add_argument( - '--select', '-s', - metavar='COLUMN', - default=[], - action=cliutil.ExtendAction, - help="""Columns to select. You can write these as comma-separated -names, and/or specify the option more than once. You can specify both -bean-query's built-in column names (like `account` and `flag`) and metadata -keys. -""") - query_group.add_argument( - '--group-by', '-g', - metavar='COLUMN', - help="""Group output by this column -""") - # query_group.add_argument( - # '--order-by', '--sort', '-r', - # metavar='COLUMN', - # help="""Order output by this column - # """), - query_group.add_argument( - '--join', '-j', - metavar='OP', - type=join_arg.enum_type, - help=f"""Join your WHERE conditions with this operator. -Choices are {join_arg.choices_str()}. Default 'and'. -"""), - query_group.add_argument( + parser.add_argument( 'query', nargs=argparse.ZERO_OR_MORE, - help="""Full query or WHERE conditions to run non-interactively + default=[], + help="""Query to run non-interactively. If none is provided, and +standard input is not a terminal, reads the query from stdin instead. """) args = parser.parse_args(arglist) @@ -413,20 +409,14 @@ def main(arglist: Optional[Sequence[str]]=None, config = configmod.Config() config.load_file() - fy = config.fiscal_year_begin() - if args.stop_date is None and args.start_date is not None: - args.stop_date = fy.next_fy_date(args.start_date) - try: - query = build_query(args, fy, None if sys.stdin.isatty() else sys.stdin) - except ValueError as error: - logger.error(error.args[0], exc_info=logger.isEnabledFor(logging.DEBUG)) - return 2 - is_interactive = query is None and sys.stdin.isatty() + query = ' '.join(args.query).strip() + if not query and not sys.stdin.isatty(): + query = sys.stdin.read().strip() if args.report_type is None: try: args.report_type = ReportFormat[args.output_file.suffix[1:].upper()] except (AttributeError, KeyError): - args.report_type = ReportFormat.TEXT if is_interactive else ReportFormat.ODS + args.report_type = ReportFormat.ODS if query else ReportFormat.TEXT load_func = BooksLoader( config.books_loader(), @@ -435,7 +425,7 @@ def main(arglist: Optional[Sequence[str]]=None, [rewrite.RewriteRuleset.from_yaml(path) for path in args.rewrite_rules], ) shell = BQLShell( - is_interactive, + not query, load_func, stdout, args.report_type.value, @@ -443,10 +433,10 @@ def main(arglist: Optional[Sequence[str]]=None, config.rt_wrapper(), ) shell.on_Reload() - if query is None: - shell.cmdloop() - else: + if query: shell.onecmd(query) + else: + shell.cmdloop() if not shell.ods.is_empty(): shell.ods.set_common_properties(config.books_repo()) @@ -454,7 +444,7 @@ def main(arglist: Optional[Sequence[str]]=None, if args.output_file is None: out_dir_path = config.repository_path() or Path() args.output_file = out_dir_path / 'QueryResults_{}.ods'.format( - datetime.datetime.now().isoformat(timespec='minutes'), + datetime.datetime.now().isoformat(timespec='seconds'), ) logger.info("Writing spreadsheet to %s", args.output_file) ods_file = cliutil.bytes_output(args.output_file, stdout) diff --git a/tests/test_reports_query.py b/tests/test_reports_query.py index ed2ba64..492069b 100644 --- a/tests/test_reports_query.py +++ b/tests/test_reports_query.py @@ -38,10 +38,6 @@ class MockRewriteRuleset: yield post._replace(units=testutil.Amount(number, currency)) -@pytest.fixture(scope='module') -def fy(): - return FiscalYear(3, 1) - @pytest.fixture(scope='module') def rt(): return rtutil.RT(testutil.RTClient()) @@ -52,15 +48,6 @@ def pipe_main(arglist, config, stdout_type=io.StringIO): returncode = qmod.main(arglist, stdout, stderr, config) return returncode, stdout, stderr -def query_args(query=None, start_date=None, stop_date=None, join=None, select=None): - if isinstance(join, str): - join = qmod.JoinOperator[join] - if select is None: - select = [] - elif isinstance(select, str): - select = select.split(',') - return argparse.Namespace(**locals()) - def test_books_loader_empty(): result = qmod.BooksLoader(None)() assert not result.entries @@ -90,130 +77,6 @@ def test_books_loader_rewrites(): assert numbers assert all(abs(number) >= 40 for number in numbers) -@pytest.mark.parametrize('file_s', [None, '', ' \n \n\n']) -def test_build_query_empty(fy, file_s): - args = query_args() - if file_s is None: - query = qmod.build_query(args, fy) - else: - with io.StringIO(file_s) as qfile: - query = qmod.build_query(args, fy, qfile) - assert query is None - -@pytest.mark.parametrize('query_str', [ - 'SELECT * WHERE date >= 2018-03-01', - 'select *', - 'JOURNAL "Income:Donations"', - 'journal', - 'BALANCES FROM year=2018', - 'balances', -]) -def test_build_query_in_arglist(fy, query_str): - args = query_args(query_str.split(), testutil.PAST_DATE, testutil.FUTURE_DATE) - assert qmod.build_query(args, fy) == query_str - -@pytest.mark.parametrize('argname,argval', [ - ('join', qmod.JoinOperator.AND), - ('join', qmod.JoinOperator.OR), - ('select', 'date,flag'), - ('select', 'position,rt-id'), -]) -def test_build_query_cant_mix_switches_with_full_query(fy, argname, argval): - args = query_args(['journal'], **{argname: argval}) - with pytest.raises(ValueError): - qmod.build_query(args, fy) - -@pytest.mark.parametrize('count,join_op', enumerate(qmod.JoinOperator, 1)) -def test_build_query_where_arglist_conditions(fy, count, join_op): - conds = ['account ~ "^Income:"', 'year >= 2018'][:count] - args = query_args(conds, join=join_op.name) - query = qmod.build_query(args, fy) - assert query.startswith('SELECT ') - cond_index = query.index(' WHERE ') + 7 - assert query[cond_index:] == '({})'.format(join_op.join(conds)) - -@pytest.mark.parametrize('select', [ - ['flag'], - ['check'], - ['flag', 'month'], - ['cost_label', 'cost_metakey'], - ['approval', 'receipt'], -]) -def test_build_query_select_fields(fy, select): - args = query_args(['year>2018'], select=list(select)) - query = qmod.build_query(args, fy) - assert query.startswith('SELECT ') - start_index = 7 - for field in select: - if field != 'flag' and field != 'month' and field != 'cost_label': - field = f'ANY_META("{field}") AS {field.replace("-", "_")}' - match = re.search(rf',\s*{re.escape(field)}\b', query) - assert match, f"field {field!r} not found in query: {query!r}" - assert match.start() >= start_index - start_index = match.end() - assert query[start_index:start_index + 7] == ' WHERE ' - -@pytest.mark.parametrize('argname,date_arg', itertools.product( - ['start_date', 'stop_date'], - [testutil.FY_START_DATE, testutil.FY_START_DATE.year], -)) -def test_build_query_one_date_arg(fy, argname, date_arg): - query_kwargs = { - argname: date_arg, - 'query': ['flag = "*"', 'flag = "!"'], - 'join': 'OR', - } - args = query_args(**query_kwargs) - query = qmod.build_query(args, fy) - assert query.startswith('SELECT ') - cond_index = query.index(' WHERE ') + 7 - if argname == 'start_date': - expect_op = '>=' - year_to_date = fy.first_date - else: - expect_op = '<' - year_to_date = fy.next_fy_date - if not isinstance(date_arg, datetime.date): - date_arg = year_to_date(date_arg) - assert query[cond_index:] == '({}) AND date {} {}'.format( - ' OR '.join(query_kwargs['query']), expect_op, date_arg.isoformat(), - ) - -@pytest.mark.parametrize('start_date,stop_date', itertools.product( - [testutil.PAST_DATE, testutil.PAST_DATE.year], - [testutil.FUTURE_DATE, testutil.FUTURE_DATE.year], -)) -def test_build_query_two_date_args(fy, start_date, stop_date): - args = query_args(['account ~ "^Equity:"'], start_date, stop_date, 'AND') - query = qmod.build_query(args, fy) - assert query.startswith('SELECT ') - cond_index = query.index(' WHERE ') + 7 - if isinstance(start_date, int): - start_date = fy.first_date(start_date) - if isinstance(stop_date, int): - stop_date = fy.next_fy_date(stop_date) - assert query[cond_index:] == '({}) AND date >= {} AND date < {}'.format( - args.query[0], start_date.isoformat(), stop_date.isoformat(), - ) - -def test_build_query_plain_from_file(fy): - with io.StringIO("SELECT *\n WHERE account ~ '^Assets:';\n") as qfile: - query = qmod.build_query(query_args(), fy, qfile) - assert re.fullmatch(r"SELECT \*\s+WHERE account ~ '\^Assets:';\s*", query) - -def test_build_query_from_file_where_clauses(fy): - conds = ["account ~ '^Income:'", "account ~ '^Expenses:'"] - args = query_args(None, testutil.PAST_DATE, testutil.FUTURE_DATE, 'OR') - with io.StringIO(''.join(f'{s}\n' for s in conds)) as qfile: - query = qmod.build_query(args, fy, qfile) - assert query.startswith('SELECT ') - cond_index = query.index(' WHERE ') + 7 - assert query[cond_index:] == '({}) AND date >= {} AND date < {}'.format( - ' OR '.join(conds), - testutil.PAST_DATE.isoformat(), - testutil.FUTURE_DATE.isoformat(), - ) - @pytest.mark.parametrize('arglist,fy', testutil.combine_values( [['--report-type', 'text'], ['--format=text'], ['-f', 'txt']], range(2018, 2021), @@ -362,7 +225,11 @@ def test_ods_is_empty(): def test_ods_output(fy, account, amt_prefix): books_path = testutil.test_path(f'books/books/{fy}.beancount') config = testutil.TestConfig(books_path=books_path) - arglist = ['-O', '-', '-f', 'ods', f'account ~ "^{account}:"'] + arglist = [ + '-O', '-', + '-f', 'ods', + f'SELECT date, narration, UNITS(position) WHERE account ~ "^{account}:"', + ] returncode, stdout, stderr = pipe_main(arglist, config, io.BytesIO) assert returncode == 0 stdout.seek(0) @@ -371,7 +238,7 @@ def test_ods_output(fy, account, amt_prefix): next(rows) # Skip header row amt_pattern = rf'^{re.escape(amt_prefix)}\d' for count, row in enumerate(rows, 1): - date, entity, narration, amount = row.childNodes + date, narration, amount = row.childNodes assert re.fullmatch(rf'{fy}-\d{{2}}-\d{{2}}', date.text) assert narration.text.startswith(f'{fy} ') assert re.match(amt_pattern, amount.text)