query: Development cleanup.

Reorder classes for nicer readability. Put shorter classes higher up, keep
related classes together.

Add developer-facing comments.

Remove unused imports.
This commit is contained in:
Brett Smith 2021-03-09 10:33:11 -05:00
parent 5893d6a59a
commit c5a2c2d39b

View file

@ -11,12 +11,10 @@ import datetime
import enum import enum
import itertools import itertools
import logging import logging
import re
import sys import sys
from typing import ( from typing import (
cast, cast,
AbstractSet,
Any, Any,
Callable, Callable,
Dict, Dict,
@ -35,8 +33,6 @@ from typing import (
from ..beancount_types import ( from ..beancount_types import (
MetaKey, MetaKey,
MetaValue, MetaValue,
Posting,
Transaction,
) )
from decimal import Decimal from decimal import Decimal
@ -91,6 +87,81 @@ QueryStatement = Union[
bc_query_parser.Select, bc_query_parser.Select,
] ]
# This class supports type checking. Beancount code dynamically sets the
# ``store`` attribute, in bc_query_execute.execute_query().
class Context(bc_query_execute.RowContext):
store: Store
class MetaDocs(bc_query_env.AnyMeta):
"""Return a list of document links from metadata."""
def __init__(self, operands: List[str]) -> None:
super(bc_query_env.AnyMeta, self).__init__(operands, list)
# The second argument is our return type.
# It should match the annotated return type of __call__.
def __call__(self, context: Context) -> List[str]:
raw_value = super().__call__(context)
if isinstance(raw_value, str):
return raw_value.split()
else:
return []
class StrMeta(bc_query_env.AnyMeta):
"""Looks up metadata like AnyMeta, then always returns a string."""
def __init__(self, operands: List[str]) -> None:
super(bc_query_env.AnyMeta, self).__init__(operands, str)
def __call__(self, context: Context) -> str:
raw_value = super().__call__(context)
if raw_value is None:
return ''
else:
return str(raw_value)
class AggregateSet(bc_query_compile.EvalAggregator):
__intypes__ = [object]
def __init__(self, operands: List[str]) -> None:
super().__init__(operands, set)
def allocate(self, allocator: bc_query_execute.Allocator) -> None:
"""Allocate and save an index handle into result storage."""
self.handle = allocator.allocate()
def initialize(self, store: Store) -> None:
"""Prepare result storage for a new aggregation."""
store[self.handle] = self.dtype()
# self.dtype() is our return type, aka the second argument to __init__
# above, aka the annotated return type of __call__.
def update(self, store: Store, context: Context) -> None:
"""Update existing storage with new result data."""
value, = self.eval_args(context)
if isinstance(value, Sequence) and not isinstance(value, (str, tuple)):
store[self.handle].update(value)
else:
store[self.handle].add(value)
def __call__(self, context: Context) -> set:
"""Return the result for an aggregation."""
return context.store[self.handle] # type:ignore[no-any-return]
class FilterPostingsEnvironment(bc_query_env.FilterPostingsEnvironment):
functions: EnvironmentFunctions = bc_query_env.FilterPostingsEnvironment.functions.copy() # type:ignore[assignment]
functions['meta_docs'] = MetaDocs
functions['str_meta'] = StrMeta
class TargetsEnvironment(bc_query_env.TargetsEnvironment):
functions = FilterPostingsEnvironment.functions.copy()
functions.update(bc_query_env.AGGREGATOR_FUNCTIONS)
functions['set'] = AggregateSet
class BooksLoader: class BooksLoader:
"""Closure to load books with a zero-argument callable """Closure to load books with a zero-argument callable
@ -128,8 +199,91 @@ class BooksLoader:
return result return result
class BQLShell(bc_query_shell.BQLShell):
def __init__(
self,
is_interactive: bool,
loadfun: Callable[[], books.LoadResult],
outfile: TextIO,
default_format: str='text',
do_numberify: bool=False,
rt_wrapper: Optional[rtutil.RT]=None,
) -> None:
super().__init__(is_interactive, loadfun, outfile, default_format, do_numberify)
self.env_postings = FilterPostingsEnvironment()
self.env_targets = TargetsEnvironment()
self.ods = QueryODS(rt_wrapper)
def on_Select(self, statement: QueryStatement) -> None:
output_format: str = self.vars['format']
try:
render_func = getattr(self, f'_render_{output_format}')
except AttributeError:
logger.error("unknown output format %r", output_format)
return
try:
logger.debug("compiling query")
compiled_query = bc_query_compile.compile(
statement, self.env_targets, self.env_postings, self.env_entries,
)
logger.debug("executing query")
row_types, rows = bc_query_execute.execute_query(
compiled_query, self.entries, self.options_map,
)
if self.vars['numberify']:
logger.debug("numberifying query")
row_types, rows = bc_query_numberify.numberify_results(
row_types, rows, self.options_map['dcontext'].build(),
)
except Exception as error:
logger.error(str(error), exc_info=logger.isEnabledFor(logging.DEBUG))
return
if not rows and output_format != 'ods':
print("(empty)", file=self.outfile)
else:
logger.debug("rendering query as %s", output_format)
render_func(statement, row_types, rows)
def _render_csv(self, statement: QueryStatement, row_types: RowTypes, rows: Rows) -> None:
bc_query_render.render_csv(
row_types,
rows,
self.options_map['dcontext'],
self.outfile,
self.vars['expand'],
)
def _render_ods(self, statement: QueryStatement, row_types: RowTypes, rows: Rows) -> None:
self.ods.write_query(statement, row_types, rows)
logger.info(
"%s rows of results saved in sheet %s",
len(rows),
self.ods.sheet.getAttribute('name'),
)
def _render_text(self, statement: QueryStatement, row_types: RowTypes, rows: Rows) -> None:
with contextlib.ExitStack() as stack:
if self.is_interactive:
output = stack.enter_context(self.get_pager())
else:
output = self.outfile
bc_query_render.render_text(
row_types,
rows,
self.options_map['dcontext'],
output,
self.vars['expand'],
self.vars['boxed'],
self.vars['spaced'],
)
class QueryODS(core.BaseODS[NamedTuple, None]): class QueryODS(core.BaseODS[NamedTuple, None]):
META_FNAMES = frozenset([ META_FNAMES = frozenset([
# Names of functions, as defined in Environments, that look up
# posting metadata that could contain documentation links
'any_meta', 'any_meta',
'entry_meta', 'entry_meta',
'meta', 'meta',
@ -264,154 +418,6 @@ class QueryODS(core.BaseODS[NamedTuple, None]):
)) ))
# This class mostly supports type checking. Beancount code dynamically sets the
# ``store`` attribute, in bc_query_execute.execute_query().
class Context(bc_query_execute.RowContext):
store: Store
class MetaDocs(bc_query_env.AnyMeta):
"""Return a list of document links from metadata."""
def __init__(self, operands: List[str]) -> None:
super(bc_query_env.AnyMeta, self).__init__(operands, list)
def __call__(self, context: Context) -> List[str]:
raw_value = super().__call__(context)
if isinstance(raw_value, str):
return raw_value.split()
else:
return []
class StrMeta(bc_query_env.AnyMeta):
"""Looks up metadata like AnyMeta, then always returns a string."""
def __init__(self, operands: List[str]) -> None:
super(bc_query_env.AnyMeta, self).__init__(operands, str)
def __call__(self, context: Context) -> str:
raw_value = super().__call__(context)
if raw_value is None:
return ''
else:
return str(raw_value)
class AggregateSet(bc_query_compile.EvalAggregator):
__intypes__ = [object]
def __init__(self, operands: List[str]) -> None:
super().__init__(operands, set)
def allocate(self, allocator: bc_query_execute.Allocator) -> None:
self.handle = allocator.allocate()
def initialize(self, store: Store) -> None:
store[self.handle] = self.dtype()
def update(self, store: Store, context: Context) -> None:
value, = self.eval_args(context)
if isinstance(value, Sequence) and not isinstance(value, (str, tuple)):
store[self.handle].update(value)
else:
store[self.handle].add(value)
def __call__(self, context: Context) -> set:
return context.store[self.handle] # type:ignore[no-any-return]
class FilterPostingsEnvironment(bc_query_env.FilterPostingsEnvironment):
functions: EnvironmentFunctions = bc_query_env.FilterPostingsEnvironment.functions.copy() # type:ignore[assignment]
functions['meta_docs'] = MetaDocs
functions['str_meta'] = StrMeta
class TargetsEnvironment(bc_query_env.TargetsEnvironment):
functions = FilterPostingsEnvironment.functions.copy()
functions.update(bc_query_env.AGGREGATOR_FUNCTIONS)
functions['set'] = AggregateSet
class BQLShell(bc_query_shell.BQLShell):
def __init__(
self,
is_interactive: bool,
loadfun: Callable[[], books.LoadResult],
outfile: TextIO,
default_format: str='text',
do_numberify: bool=False,
rt_wrapper: Optional[rtutil.RT]=None,
) -> None:
super().__init__(is_interactive, loadfun, outfile, default_format, do_numberify)
self.env_postings = FilterPostingsEnvironment()
self.env_targets = TargetsEnvironment()
self.ods = QueryODS(rt_wrapper)
def on_Select(self, statement: QueryStatement) -> None:
output_format: str = self.vars['format']
try:
render_func = getattr(self, f'_render_{output_format}')
except AttributeError:
logger.error("unknown output format %r", output_format)
return
try:
logger.debug("compiling query")
compiled_query = bc_query_compile.compile(
statement, self.env_targets, self.env_postings, self.env_entries,
)
logger.debug("executing query")
row_types, rows = bc_query_execute.execute_query(
compiled_query, self.entries, self.options_map,
)
if self.vars['numberify']:
logger.debug("numberifying query")
row_types, rows = bc_query_numberify.numberify_results(
row_types, rows, self.options_map['dcontext'].build(),
)
except Exception as error:
logger.error(str(error), exc_info=logger.isEnabledFor(logging.DEBUG))
return
if not rows and output_format != 'ods':
print("(empty)", file=self.outfile)
else:
logger.debug("rendering query as %s", output_format)
render_func(statement, row_types, rows)
def _render_csv(self, statement: QueryStatement, row_types: RowTypes, rows: Rows) -> None:
bc_query_render.render_csv(
row_types,
rows,
self.options_map['dcontext'],
self.outfile,
self.vars['expand'],
)
def _render_ods(self, statement: QueryStatement, row_types: RowTypes, rows: Rows) -> None:
self.ods.write_query(statement, row_types, rows)
logger.info(
"%s rows of results saved in sheet %s",
len(rows),
self.ods.sheet.getAttribute('name'),
)
def _render_text(self, statement: QueryStatement, row_types: RowTypes, rows: Rows) -> None:
with contextlib.ExitStack() as stack:
if self.is_interactive:
output = stack.enter_context(self.get_pager())
else:
output = self.outfile
bc_query_render.render_text(
row_types,
rows,
self.options_map['dcontext'],
output,
self.vars['expand'],
self.vars['boxed'],
self.vars['spaced'],
)
class ReportFormat(enum.Enum): class ReportFormat(enum.Enum):
TEXT = 'text' TEXT = 'text'
TXT = TEXT TXT = TEXT
@ -442,6 +448,10 @@ class SetFYDates(argparse.Action):
) -> None: ) -> None:
value = cliutil.year_or_date_arg(str(values)) value = cliutil.year_or_date_arg(str(values))
namespace.start_date = value namespace.start_date = value
# The configuration hasn't been loaded, so we don't know the boundaries
# of a fiscal year yet. But that's okay, because we just need to set
# enough so that when these arguments are passed to a BooksLoader,
# it'll load the right fiscal year.
if isinstance(value, int): if isinstance(value, int):
namespace.stop_date = value namespace.stop_date = value
else: else: