query: Add BQL functions for dealing with link metadata.

query-report was heading to a place where it was going to bifurcate. You
could structure input with its own special input switches, and with ODS
output, it would have its own dedicated grouping logic and use that.

But those things shouldn't be tied together for users. Instead, add
functions to BQL to be able to do the kind of grouping we want. This commit
adds those. Next we'll extend the ODS output to detect and format these
groups correctly.
This commit is contained in:
Brett Smith 2021-03-08 13:48:25 -05:00
parent 0f58960b67
commit 8af45e5f8a
2 changed files with 115 additions and 258 deletions

View file

@ -22,6 +22,7 @@ from typing import (
Dict, Dict,
Iterable, Iterable,
Iterator, Iterator,
List,
Mapping, Mapping,
NamedTuple, NamedTuple,
Optional, Optional,
@ -63,12 +64,27 @@ BUILTIN_FIELDS: AbstractSet[str] = frozenset(itertools.chain(
bc_query_env.TargetsEnvironment.functions, # type:ignore[has-type] bc_query_env.TargetsEnvironment.functions, # type:ignore[has-type]
)) ))
PROGNAME = 'query-report' PROGNAME = 'query-report'
QUERY_PARSER = bc_query_parser.Parser()
logger = logging.getLogger('conservancy_beancount.reports.query') logger = logging.getLogger('conservancy_beancount.reports.query')
CellFunc = Callable[[Any], odf.table.TableCell] CellFunc = Callable[[Any], odf.table.TableCell]
EnvironmentFunctions = Dict[
# The real key type is something like:
# Union[str, Tuple[str, Type, ...]]
# but two issues with that. One, you can't use Ellipses in a Tuple like
# that, so there's no short way to declare this. Second, Beancount doesn't
# declare it anyway, and mypy infers it as Sequence[object]. So just use
# that.
Sequence[object],
Type[bc_query_compile.EvalFunction],
]
RowTypes = Sequence[Tuple[str, Type]] RowTypes = Sequence[Tuple[str, Type]]
Rows = Sequence[NamedTuple] Rows = Sequence[NamedTuple]
Store = List[Any]
QueryStatement = Union[
bc_query_parser.Balances,
bc_query_parser.Journal,
bc_query_parser.Select,
]
class BooksLoader: class BooksLoader:
"""Closure to load books with a zero-argument callable """Closure to load books with a zero-argument callable
@ -176,6 +192,73 @@ class QueryODS(core.BaseODS[NamedTuple, None]):
)) ))
# This class mostly supports type checking. Beancount code dynamically sets the
# ``store`` attribute, in bc_query_execute.execute_query().
class Context(bc_query_execute.RowContext):
store: Store
class MetaDocs(bc_query_env.AnyMeta):
"""Return a list of document links from metadata."""
def __init__(self, operands: List[str]) -> None:
super(bc_query_env.AnyMeta, self).__init__(operands, list)
def __call__(self, context: Context) -> List[str]:
raw_value = super().__call__(context)
if isinstance(raw_value, str):
return raw_value.split()
else:
return []
class StrMeta(bc_query_env.AnyMeta):
"""Looks up metadata like AnyMeta, then always returns a string."""
def __init__(self, operands: List[str]) -> None:
super(bc_query_env.AnyMeta, self).__init__(operands, str)
def __call__(self, context: Context) -> str:
raw_value = super().__call__(context)
if raw_value is None:
return ''
else:
return str(raw_value)
class AggregateSet(bc_query_compile.EvalAggregator):
__intypes__ = [object]
def __init__(self, operands: List[str]) -> None:
super().__init__(operands, set)
def allocate(self, allocator: bc_query_execute.Allocator) -> None:
self.handle = allocator.allocate()
def initialize(self, store: Store) -> None:
store[self.handle] = self.dtype()
def update(self, store: Store, context: Context) -> None:
value, = self.eval_args(context)
if isinstance(value, Sequence) and not isinstance(value, str):
store[self.handle].update(value)
else:
store[self.handle].add(value)
def __call__(self, context: Context) -> set:
return context.store[self.handle] # type:ignore[no-any-return]
class FilterPostingsEnvironment(bc_query_env.FilterPostingsEnvironment):
functions: EnvironmentFunctions = bc_query_env.FilterPostingsEnvironment.functions.copy() # type:ignore[assignment]
functions['meta_docs'] = MetaDocs
functions['str_meta'] = StrMeta
class TargetsEnvironment(bc_query_env.TargetsEnvironment):
functions = FilterPostingsEnvironment.functions.copy()
functions.update(bc_query_env.AGGREGATOR_FUNCTIONS)
functions['set'] = AggregateSet
class BQLShell(bc_query_shell.BQLShell): class BQLShell(bc_query_shell.BQLShell):
def __init__( def __init__(
self, self,
@ -187,9 +270,11 @@ class BQLShell(bc_query_shell.BQLShell):
rt_wrapper: Optional[rtutil.RT]=None, rt_wrapper: Optional[rtutil.RT]=None,
) -> None: ) -> None:
super().__init__(is_interactive, loadfun, outfile, default_format, do_numberify) super().__init__(is_interactive, loadfun, outfile, default_format, do_numberify)
self.env_postings = FilterPostingsEnvironment()
self.env_targets = TargetsEnvironment()
self.ods = QueryODS(rt_wrapper) self.ods = QueryODS(rt_wrapper)
def on_Select(self, statement: str) -> None: def on_Select(self, statement: QueryStatement) -> None:
output_format: str = self.vars['format'] output_format: str = self.vars['format']
try: try:
render_func = getattr(self, f'_render_{output_format}') render_func = getattr(self, f'_render_{output_format}')
@ -206,7 +291,7 @@ class BQLShell(bc_query_shell.BQLShell):
row_types, rows = bc_query_execute.execute_query( row_types, rows = bc_query_execute.execute_query(
compiled_query, self.entries, self.options_map, compiled_query, self.entries, self.options_map,
) )
if self.vars['numberify'] and output_format != 'ods': if self.vars['numberify']:
logger.debug("numberifying query") logger.debug("numberifying query")
row_types, rows = bc_query_numberify.numberify_results( row_types, rows = bc_query_numberify.numberify_results(
row_types, rows, self.options_map['dcontext'].build(), row_types, rows, self.options_map['dcontext'].build(),
@ -251,14 +336,6 @@ class BQLShell(bc_query_shell.BQLShell):
) )
class JoinOperator(enum.Enum):
AND = 'AND'
OR = 'OR'
def join(self, parts: Iterable[str]) -> str:
return f' {self.value} '.join(parts)
class ReportFormat(enum.Enum): class ReportFormat(enum.Enum):
TEXT = 'text' TEXT = 'text'
TXT = TEXT TXT = TEXT
@ -266,54 +343,6 @@ class ReportFormat(enum.Enum):
ODS = 'ods' ODS = 'ods'
def _date_condition(
date: Union[int, datetime.date],
year_to_date: Callable[[int], datetime.date],
op: str,
) -> str:
if isinstance(date, int):
date = year_to_date(date)
return f'date {op} {date.isoformat()}'
def build_query(
args: argparse.Namespace,
fy: books.FiscalYear,
in_file: Optional[TextIO]=None,
) -> Optional[str]:
if not args.query:
args.query = [] if in_file is None else [line[:-1] for line in in_file]
plain_query = ' '.join(args.query)
if not plain_query or plain_query.isspace():
return None
try:
QUERY_PARSER.parse(plain_query)
except bc_query_parser.ParseError:
if args.join is None:
args.join = JoinOperator.AND
select = [
'date',
'ANY_META("entity") AS entity',
'narration AS description',
'COST(position) AS booked_amount',
*(f'ANY_META("{field}") AS {field.replace("-", "_")}'
if field not in BUILTIN_FIELDS
and re.fullmatch(r'[a-z][-_A-Za-z0-9]*', field)
else field
for field in args.select),
]
conds = [f'({args.join.join(args.query)})']
if args.start_date is not None:
conds.append(_date_condition(args.start_date, fy.first_date, '>='))
if args.stop_date is not None:
conds.append(_date_condition(args.stop_date, fy.next_fy_date, '<'))
return f'SELECT {", ".join(select)} WHERE {" AND ".join(conds)}'
else:
if args.join:
raise ValueError("cannot specify --join with a full query")
if args.select:
raise ValueError("cannot specify --select with a full query")
return plain_query
def parse_arguments(arglist: Optional[Sequence[str]]=None) -> argparse.Namespace: def parse_arguments(arglist: Optional[Sequence[str]]=None) -> argparse.Namespace:
parser = argparse.ArgumentParser(prog=PROGNAME) parser = argparse.ArgumentParser(prog=PROGNAME)
cliutil.add_version_argument(parser) cliutil.add_version_argument(parser)
@ -321,20 +350,18 @@ def parse_arguments(arglist: Optional[Sequence[str]]=None) -> argparse.Namespace
parser.add_argument( parser.add_argument(
'--begin', '--start', '-b', '--begin', '--start', '-b',
dest='start_date', dest='start_date',
metavar='DATE', metavar='YEAR',
type=cliutil.year_or_date_arg, type=cliutil.year_or_date_arg,
help="""Begin loading entries from this fiscal year. When query-report help="""Begin loading entries from this fiscal year. You can specify a
builds the query, it will include a condition `date >= DATE`. full date, and %(prog)s will use the fiscal year for that date.
""") """)
parser.add_argument( parser.add_argument(
'--end', '--stop', '-e', '--end', '--stop', '-e',
dest='stop_date', dest='stop_date',
metavar='DATE', metavar='YEAR',
type=cliutil.year_or_date_arg, type=cliutil.year_or_date_arg,
help="""End loading entries from this fiscal year. When query-report help="""End loading entries at this fiscal year. You can specify a
builds the query, it will include a condition `date < DATE`. If you specify a full date, and %(prog)s will use the fiscal year for that date.
begin date but not an end date, the default end date will be the end of the
fiscal year of the begin date.
""") """)
cliutil.add_rewrite_rules_argument(parser) cliutil.add_rewrite_rules_argument(parser)
format_arg = cliutil.EnumArgument(ReportFormat) format_arg = cliutil.EnumArgument(ReportFormat)
@ -344,7 +371,8 @@ fiscal year of the begin date.
type=format_arg.enum_type, type=format_arg.enum_type,
help="""Format of report to generate. Choices are help="""Format of report to generate. Choices are
{format_arg.choices_str()}. Default is guessed from your output filename {format_arg.choices_str()}. Default is guessed from your output filename
extension, or 'text' if that fails. extension. If that fails, default is 'text' for interactive output, and 'ods'
otherwise.
""") """)
parser.add_argument( parser.add_argument(
'--numberify', '-m', '--numberify', '-m',
@ -359,44 +387,12 @@ extension, or 'text' if that fails.
The default is stdout for text and CSV reports, and a generated filename for The default is stdout for text and CSV reports, and a generated filename for
ODS reports. ODS reports.
""") """)
parser.add_argument(
query_group = parser.add_argument_group("query options", """
You can write a single full query as a command line argument (like bean-query),
or you can write individual WHERE condition(s) as arguments. If you write
WHERE conditions, these options are used to build the rest of the query.
""")
join_arg = cliutil.EnumArgument(JoinOperator)
query_group.add_argument(
'--select', '-s',
metavar='COLUMN',
default=[],
action=cliutil.ExtendAction,
help="""Columns to select. You can write these as comma-separated
names, and/or specify the option more than once. You can specify both
bean-query's built-in column names (like `account` and `flag`) and metadata
keys.
""")
query_group.add_argument(
'--group-by', '-g',
metavar='COLUMN',
help="""Group output by this column
""")
# query_group.add_argument(
# '--order-by', '--sort', '-r',
# metavar='COLUMN',
# help="""Order output by this column
# """),
query_group.add_argument(
'--join', '-j',
metavar='OP',
type=join_arg.enum_type,
help=f"""Join your WHERE conditions with this operator.
Choices are {join_arg.choices_str()}. Default 'and'.
"""),
query_group.add_argument(
'query', 'query',
nargs=argparse.ZERO_OR_MORE, nargs=argparse.ZERO_OR_MORE,
help="""Full query or WHERE conditions to run non-interactively default=[],
help="""Query to run non-interactively. If none is provided, and
standard input is not a terminal, reads the query from stdin instead.
""") """)
args = parser.parse_args(arglist) args = parser.parse_args(arglist)
@ -413,20 +409,14 @@ def main(arglist: Optional[Sequence[str]]=None,
config = configmod.Config() config = configmod.Config()
config.load_file() config.load_file()
fy = config.fiscal_year_begin() query = ' '.join(args.query).strip()
if args.stop_date is None and args.start_date is not None: if not query and not sys.stdin.isatty():
args.stop_date = fy.next_fy_date(args.start_date) query = sys.stdin.read().strip()
try:
query = build_query(args, fy, None if sys.stdin.isatty() else sys.stdin)
except ValueError as error:
logger.error(error.args[0], exc_info=logger.isEnabledFor(logging.DEBUG))
return 2
is_interactive = query is None and sys.stdin.isatty()
if args.report_type is None: if args.report_type is None:
try: try:
args.report_type = ReportFormat[args.output_file.suffix[1:].upper()] args.report_type = ReportFormat[args.output_file.suffix[1:].upper()]
except (AttributeError, KeyError): except (AttributeError, KeyError):
args.report_type = ReportFormat.TEXT if is_interactive else ReportFormat.ODS args.report_type = ReportFormat.ODS if query else ReportFormat.TEXT
load_func = BooksLoader( load_func = BooksLoader(
config.books_loader(), config.books_loader(),
@ -435,7 +425,7 @@ def main(arglist: Optional[Sequence[str]]=None,
[rewrite.RewriteRuleset.from_yaml(path) for path in args.rewrite_rules], [rewrite.RewriteRuleset.from_yaml(path) for path in args.rewrite_rules],
) )
shell = BQLShell( shell = BQLShell(
is_interactive, not query,
load_func, load_func,
stdout, stdout,
args.report_type.value, args.report_type.value,
@ -443,10 +433,10 @@ def main(arglist: Optional[Sequence[str]]=None,
config.rt_wrapper(), config.rt_wrapper(),
) )
shell.on_Reload() shell.on_Reload()
if query is None: if query:
shell.cmdloop()
else:
shell.onecmd(query) shell.onecmd(query)
else:
shell.cmdloop()
if not shell.ods.is_empty(): if not shell.ods.is_empty():
shell.ods.set_common_properties(config.books_repo()) shell.ods.set_common_properties(config.books_repo())
@ -454,7 +444,7 @@ def main(arglist: Optional[Sequence[str]]=None,
if args.output_file is None: if args.output_file is None:
out_dir_path = config.repository_path() or Path() out_dir_path = config.repository_path() or Path()
args.output_file = out_dir_path / 'QueryResults_{}.ods'.format( args.output_file = out_dir_path / 'QueryResults_{}.ods'.format(
datetime.datetime.now().isoformat(timespec='minutes'), datetime.datetime.now().isoformat(timespec='seconds'),
) )
logger.info("Writing spreadsheet to %s", args.output_file) logger.info("Writing spreadsheet to %s", args.output_file)
ods_file = cliutil.bytes_output(args.output_file, stdout) ods_file = cliutil.bytes_output(args.output_file, stdout)

View file

@ -38,10 +38,6 @@ class MockRewriteRuleset:
yield post._replace(units=testutil.Amount(number, currency)) yield post._replace(units=testutil.Amount(number, currency))
@pytest.fixture(scope='module')
def fy():
return FiscalYear(3, 1)
@pytest.fixture(scope='module') @pytest.fixture(scope='module')
def rt(): def rt():
return rtutil.RT(testutil.RTClient()) return rtutil.RT(testutil.RTClient())
@ -52,15 +48,6 @@ def pipe_main(arglist, config, stdout_type=io.StringIO):
returncode = qmod.main(arglist, stdout, stderr, config) returncode = qmod.main(arglist, stdout, stderr, config)
return returncode, stdout, stderr return returncode, stdout, stderr
def query_args(query=None, start_date=None, stop_date=None, join=None, select=None):
if isinstance(join, str):
join = qmod.JoinOperator[join]
if select is None:
select = []
elif isinstance(select, str):
select = select.split(',')
return argparse.Namespace(**locals())
def test_books_loader_empty(): def test_books_loader_empty():
result = qmod.BooksLoader(None)() result = qmod.BooksLoader(None)()
assert not result.entries assert not result.entries
@ -90,130 +77,6 @@ def test_books_loader_rewrites():
assert numbers assert numbers
assert all(abs(number) >= 40 for number in numbers) assert all(abs(number) >= 40 for number in numbers)
@pytest.mark.parametrize('file_s', [None, '', ' \n \n\n'])
def test_build_query_empty(fy, file_s):
args = query_args()
if file_s is None:
query = qmod.build_query(args, fy)
else:
with io.StringIO(file_s) as qfile:
query = qmod.build_query(args, fy, qfile)
assert query is None
@pytest.mark.parametrize('query_str', [
'SELECT * WHERE date >= 2018-03-01',
'select *',
'JOURNAL "Income:Donations"',
'journal',
'BALANCES FROM year=2018',
'balances',
])
def test_build_query_in_arglist(fy, query_str):
args = query_args(query_str.split(), testutil.PAST_DATE, testutil.FUTURE_DATE)
assert qmod.build_query(args, fy) == query_str
@pytest.mark.parametrize('argname,argval', [
('join', qmod.JoinOperator.AND),
('join', qmod.JoinOperator.OR),
('select', 'date,flag'),
('select', 'position,rt-id'),
])
def test_build_query_cant_mix_switches_with_full_query(fy, argname, argval):
args = query_args(['journal'], **{argname: argval})
with pytest.raises(ValueError):
qmod.build_query(args, fy)
@pytest.mark.parametrize('count,join_op', enumerate(qmod.JoinOperator, 1))
def test_build_query_where_arglist_conditions(fy, count, join_op):
conds = ['account ~ "^Income:"', 'year >= 2018'][:count]
args = query_args(conds, join=join_op.name)
query = qmod.build_query(args, fy)
assert query.startswith('SELECT ')
cond_index = query.index(' WHERE ') + 7
assert query[cond_index:] == '({})'.format(join_op.join(conds))
@pytest.mark.parametrize('select', [
['flag'],
['check'],
['flag', 'month'],
['cost_label', 'cost_metakey'],
['approval', 'receipt'],
])
def test_build_query_select_fields(fy, select):
args = query_args(['year>2018'], select=list(select))
query = qmod.build_query(args, fy)
assert query.startswith('SELECT ')
start_index = 7
for field in select:
if field != 'flag' and field != 'month' and field != 'cost_label':
field = f'ANY_META("{field}") AS {field.replace("-", "_")}'
match = re.search(rf',\s*{re.escape(field)}\b', query)
assert match, f"field {field!r} not found in query: {query!r}"
assert match.start() >= start_index
start_index = match.end()
assert query[start_index:start_index + 7] == ' WHERE '
@pytest.mark.parametrize('argname,date_arg', itertools.product(
['start_date', 'stop_date'],
[testutil.FY_START_DATE, testutil.FY_START_DATE.year],
))
def test_build_query_one_date_arg(fy, argname, date_arg):
query_kwargs = {
argname: date_arg,
'query': ['flag = "*"', 'flag = "!"'],
'join': 'OR',
}
args = query_args(**query_kwargs)
query = qmod.build_query(args, fy)
assert query.startswith('SELECT ')
cond_index = query.index(' WHERE ') + 7
if argname == 'start_date':
expect_op = '>='
year_to_date = fy.first_date
else:
expect_op = '<'
year_to_date = fy.next_fy_date
if not isinstance(date_arg, datetime.date):
date_arg = year_to_date(date_arg)
assert query[cond_index:] == '({}) AND date {} {}'.format(
' OR '.join(query_kwargs['query']), expect_op, date_arg.isoformat(),
)
@pytest.mark.parametrize('start_date,stop_date', itertools.product(
[testutil.PAST_DATE, testutil.PAST_DATE.year],
[testutil.FUTURE_DATE, testutil.FUTURE_DATE.year],
))
def test_build_query_two_date_args(fy, start_date, stop_date):
args = query_args(['account ~ "^Equity:"'], start_date, stop_date, 'AND')
query = qmod.build_query(args, fy)
assert query.startswith('SELECT ')
cond_index = query.index(' WHERE ') + 7
if isinstance(start_date, int):
start_date = fy.first_date(start_date)
if isinstance(stop_date, int):
stop_date = fy.next_fy_date(stop_date)
assert query[cond_index:] == '({}) AND date >= {} AND date < {}'.format(
args.query[0], start_date.isoformat(), stop_date.isoformat(),
)
def test_build_query_plain_from_file(fy):
with io.StringIO("SELECT *\n WHERE account ~ '^Assets:';\n") as qfile:
query = qmod.build_query(query_args(), fy, qfile)
assert re.fullmatch(r"SELECT \*\s+WHERE account ~ '\^Assets:';\s*", query)
def test_build_query_from_file_where_clauses(fy):
conds = ["account ~ '^Income:'", "account ~ '^Expenses:'"]
args = query_args(None, testutil.PAST_DATE, testutil.FUTURE_DATE, 'OR')
with io.StringIO(''.join(f'{s}\n' for s in conds)) as qfile:
query = qmod.build_query(args, fy, qfile)
assert query.startswith('SELECT ')
cond_index = query.index(' WHERE ') + 7
assert query[cond_index:] == '({}) AND date >= {} AND date < {}'.format(
' OR '.join(conds),
testutil.PAST_DATE.isoformat(),
testutil.FUTURE_DATE.isoformat(),
)
@pytest.mark.parametrize('arglist,fy', testutil.combine_values( @pytest.mark.parametrize('arglist,fy', testutil.combine_values(
[['--report-type', 'text'], ['--format=text'], ['-f', 'txt']], [['--report-type', 'text'], ['--format=text'], ['-f', 'txt']],
range(2018, 2021), range(2018, 2021),
@ -362,7 +225,11 @@ def test_ods_is_empty():
def test_ods_output(fy, account, amt_prefix): def test_ods_output(fy, account, amt_prefix):
books_path = testutil.test_path(f'books/books/{fy}.beancount') books_path = testutil.test_path(f'books/books/{fy}.beancount')
config = testutil.TestConfig(books_path=books_path) config = testutil.TestConfig(books_path=books_path)
arglist = ['-O', '-', '-f', 'ods', f'account ~ "^{account}:"'] arglist = [
'-O', '-',
'-f', 'ods',
f'SELECT date, narration, UNITS(position) WHERE account ~ "^{account}:"',
]
returncode, stdout, stderr = pipe_main(arglist, config, io.BytesIO) returncode, stdout, stderr = pipe_main(arglist, config, io.BytesIO)
assert returncode == 0 assert returncode == 0
stdout.seek(0) stdout.seek(0)
@ -371,7 +238,7 @@ def test_ods_output(fy, account, amt_prefix):
next(rows) # Skip header row next(rows) # Skip header row
amt_pattern = rf'^{re.escape(amt_prefix)}\d' amt_pattern = rf'^{re.escape(amt_prefix)}\d'
for count, row in enumerate(rows, 1): for count, row in enumerate(rows, 1):
date, entity, narration, amount = row.childNodes date, narration, amount = row.childNodes
assert re.fullmatch(rf'{fy}-\d{{2}}-\d{{2}}', date.text) assert re.fullmatch(rf'{fy}-\d{{2}}-\d{{2}}', date.text)
assert narration.text.startswith(f'{fy} ') assert narration.text.startswith(f'{fy} ')
assert re.match(amt_pattern, amount.text) assert re.match(amt_pattern, amount.text)