query: Override BQLShell.on_Select.

This is enough to continue passing the existing tests, while giving us a
clear hook to develop ODS output.
This commit is contained in:
Brett Smith 2021-03-05 17:36:19 -05:00
parent 221d42a479
commit 9c943bc8a9

View file

@ -6,6 +6,7 @@
# LICENSE.txt in the repository. # LICENSE.txt in the repository.
import argparse import argparse
import contextlib
import datetime import datetime
import enum import enum
import itertools import itertools
@ -21,10 +22,12 @@ from typing import (
Iterable, Iterable,
Iterator, Iterator,
Mapping, Mapping,
NamedTuple,
Optional, Optional,
Sequence, Sequence,
TextIO, TextIO,
Tuple, Tuple,
Type,
Union, Union,
) )
from ..beancount_types import ( from ..beancount_types import (
@ -36,9 +39,13 @@ from ..beancount_types import (
from decimal import Decimal from decimal import Decimal
from pathlib import Path from pathlib import Path
import beancount.query.numberify as bc_query_numberify
import beancount.query.query_compile as bc_query_compile
import beancount.query.query_env as bc_query_env import beancount.query.query_env as bc_query_env
import beancount.query.query_execute as bc_query_execute
import beancount.query.query_parser as bc_query_parser import beancount.query.query_parser as bc_query_parser
import beancount.query.shell as bc_query import beancount.query.query_render as bc_query_render
import beancount.query.shell as bc_query_shell
from . import core from . import core
from . import rewrite from . import rewrite
@ -55,6 +62,9 @@ PROGNAME = 'query-report'
QUERY_PARSER = bc_query_parser.Parser() QUERY_PARSER = bc_query_parser.Parser()
logger = logging.getLogger('conservancy_beancount.reports.query') logger = logging.getLogger('conservancy_beancount.reports.query')
RowTypes = Sequence[Tuple[str, Type]]
Rows = Sequence[NamedTuple]
class BooksLoader: class BooksLoader:
"""Closure to load books with a zero-argument callable """Closure to load books with a zero-argument callable
@ -88,8 +98,63 @@ class BooksLoader:
return result return result
class BQLShell(bc_query.BQLShell): class BQLShell(bc_query_shell.BQLShell):
pass def on_Select(self, statement: str) -> None:
output_format: str = self.vars['format']
try:
render_func = getattr(self, f'_render_{output_format}')
except AttributeError:
logger.error("unknown output format %r", output_format)
return
try:
logger.debug("compiling query")
compiled_query = bc_query_compile.compile(
statement, self.env_targets, self.env_postings, self.env_entries,
)
logger.debug("executing query")
row_types, rows = bc_query_execute.execute_query(
compiled_query, self.entries, self.options_map,
)
if self.vars['numberify'] and output_format != 'ods':
logger.debug("numberifying query")
row_types, rows = bc_query_numberify.numberify_results(
row_types, rows, self.options_map['dcontext'].build(),
)
except Exception as error:
logger.error(str(error), exc_info=logger.isEnabledFor(logging.DEBUG))
return
if not rows and output_format != 'ods':
print("(empty)", file=self.outfile)
else:
logger.debug("rendering query as %s", output_format)
render_func(row_types, rows)
def _render_csv(self, row_types: RowTypes, rows: Rows) -> None:
bc_query_render.render_csv(
row_types,
rows,
self.options_map['dcontext'],
self.outfile,
self.vars['expand'],
)
def _render_text(self, row_types: RowTypes, rows: Rows) -> None:
with contextlib.ExitStack() as stack:
if self.is_interactive:
output = stack.enter_context(self.get_pager())
else:
output = self.outfile
bc_query_render.render_text(
row_types,
rows,
self.options_map['dcontext'],
output,
self.vars['expand'],
self.vars['boxed'],
self.vars['spaced'],
)
class JoinOperator(enum.Enum): class JoinOperator(enum.Enum):