hooks.ledger_entry: Bring in all Ledger-specific code.

This includes the old "template" module, plus the associated
template-loading code from config.
This commit is contained in:
Brett Smith 2017-12-31 17:29:14 -05:00
parent 6d1a7cb57d
commit cdec3d9aab
6 changed files with 317 additions and 347 deletions

View file

@ -9,7 +9,7 @@ import pathlib
import babel
import babel.numbers
from . import errors, strparse, template
from . import errors, strparse
class Configuration:
HOME_PATH = pathlib.Path(os.path.expanduser('~'))
@ -241,24 +241,6 @@ class Configuration:
path = self.get_output_path(section_name)
return self._open_path(path, self.stdout, 'a')
def get_template(self, config_key, section_name=None, factory=template.Template):
section_config = self.get_section(section_name)
try:
template_s = section_config[config_key]
except KeyError:
raise errors.UserInputConfigurationError(
"template not defined in [{}]".format(section_name or self.args.use_config),
config_key,
)
return factory(
template_s,
date_fmt=section_config['date_format'],
signed_currencies=[code.strip().upper() for code in section_config['signed_currencies'].split(',')],
signed_currency_fmt=section_config['signed_currency_format'],
unsigned_currency_fmt=section_config['unsigned_currency_format'],
template_name=config_key,
)
def setup_logger(self, logger, section_name=None):
logger.setLevel(self.get_loglevel(section_name))

View file

@ -1,6 +1,273 @@
from . import HOOK_KINDS
import collections
import datetime
import decimal
import functools
import io
import operator
import re
import tokenize
import babel.numbers
from . import HOOK_KINDS
from .. import errors, strparse
class TokenTransformer:
def __init__(self, source):
try:
source = source.readline
except AttributeError:
pass
self.in_tokens = tokenize.tokenize(source)
@classmethod
def from_bytes(cls, b):
return cls(io.BytesIO(b).readline)
@classmethod
def from_str(cls, s, encoding='utf-8'):
return cls.from_bytes(s.encode(encoding))
def __iter__(self):
for ttype, tvalue, _, _, _ in self.in_tokens:
try:
transformer = getattr(self, 'transform_' + tokenize.tok_name[ttype])
except AttributeError:
raise ValueError("{} token {!r} not supported".format(ttype, tvalue))
yield from transformer(ttype, tvalue)
def _noop_transformer(self, ttype, tvalue):
yield (ttype, tvalue)
transform_ENDMARKER = _noop_transformer
def transform_ENCODING(self, ttype, tvalue):
self.in_encoding = tvalue
return self._noop_transformer(ttype, tvalue)
def transform(self):
out_bytes = tokenize.untokenize(self)
return out_bytes.decode(self.in_encoding)
class AmountTokenTransformer(TokenTransformer):
SUPPORTED_OPS = frozenset('+-*/()')
def __iter__(self):
tokens = super().__iter__()
for token in tokens:
yield token
if token[0] == tokenize.NAME:
break
else:
raise ValueError("no amount in expression")
yield from tokens
def transform_NUMBER(self, ttype, tvalue):
yield (tokenize.NAME, 'Decimal')
yield (tokenize.OP, '(')
yield (tokenize.STRING, repr(tvalue))
yield (tokenize.OP, ')')
def transform_OP(self, ttype, tvalue):
if tvalue == '{':
try:
name_type, name_value, _, _, _ = next(self.in_tokens)
close_type, close_value, _, _, _ = next(self.in_tokens)
if (name_type != tokenize.NAME
or name_value != name_value.lower()
or close_type != tokenize.OP
or close_value != '}'):
raise ValueError()
except (StopIteration, ValueError):
raise ValueError("opening { does not name variable")
yield (tokenize.NAME, name_value)
elif tvalue in self.SUPPORTED_OPS:
yield from self._noop_transformer(ttype, tvalue)
else:
raise ValueError("unsupported operator {!r}".format(tvalue))
class AccountSplitter:
TARGET_LINE_LEN = 78
# -4 because that's how many spaces prefix an account line.
TARGET_ACCTLINE_LEN = TARGET_LINE_LEN - 4
def __init__(self, signed_currencies, signed_currency_fmt, unsigned_currency_fmt,
template_name):
self.splits = []
self.metadata = []
self.signed_currency_fmt = signed_currency_fmt
self.unsigned_currency_fmt = unsigned_currency_fmt
self.signed_currencies = set(signed_currencies)
self.template_name = template_name
self._last_template_vars = object()
def is_empty(self):
return not self.splits
def add(self, account, amount_expr):
try:
clean_expr = AmountTokenTransformer.from_str(amount_expr).transform()
compiled_expr = compile(clean_expr, self.template_name, 'eval')
except (SyntaxError, tokenize.TokenError, ValueError) as error:
raise errors.UserInputConfigurationError(error.args[0], amount_expr)
else:
self.splits.append((account, compiled_expr))
self.metadata.append('')
def set_metadata(self, metadata_s):
self.metadata[-1] = metadata_s
def _currency_decimal(self, amount, currency):
return decimal.Decimal(babel.numbers.format_currency(amount, currency, '###0.###'))
def _balance_amounts(self, amounts, to_amount):
cmp_func = operator.lt if to_amount > 0 else operator.gt
should_balance = functools.partial(cmp_func, 0)
remainder = to_amount
balance_index = None
for index, (_, amount) in enumerate(amounts):
if should_balance(amount):
remainder -= amount
balance_index = index
if balance_index is None:
pass
elif (abs(remainder) / abs(to_amount)) >= decimal.Decimal('.1'):
raise errors.UserInputConfigurationError(
"template can't balance amounts to {}".format(to_amount),
self.template_name,
)
else:
account_name, start_amount = amounts[balance_index]
amounts[balance_index] = (account_name, start_amount + remainder)
def _build_amounts(self, template_vars):
amount_vars = {k: v for k, v in template_vars.items() if isinstance(v, decimal.Decimal)}
amount_vars['Decimal'] = decimal.Decimal
amounts = [
(account, self._currency_decimal(eval(amount_expr, amount_vars),
template_vars['currency']))
for account, amount_expr in self.splits
]
self._balance_amounts(amounts, template_vars['amount'])
self._balance_amounts(amounts, -template_vars['amount'])
return amounts
def _iter_splits(self, template_vars):
amounts = self._build_amounts(template_vars)
if template_vars['currency'] in self.signed_currencies:
amt_fmt = self.signed_currency_fmt
else:
amt_fmt = self.unsigned_currency_fmt
for (account, amount), metadata in zip(amounts, self.metadata):
if amount == 0:
yield ''
else:
account_s = account.format_map(template_vars)
amount_s = babel.numbers.format_currency(amount, template_vars['currency'], amt_fmt)
sep_len = max(2, self.TARGET_ACCTLINE_LEN - len(account_s) - len(amount_s))
yield '\n {}{}{}{}'.format(
account_s, ' ' * sep_len, amount_s,
metadata.format_map(template_vars),
)
def render_next(self, template_vars):
if template_vars is not self._last_template_vars:
self._split_iter = self._iter_splits(template_vars)
self._last_template_vars = template_vars
return next(self._split_iter)
class Template:
ACCOUNT_SPLIT_RE = re.compile(r'(?:\t| )\s*')
DATE_FMT = '%Y/%m/%d'
PAYEE_LINE_RE = re.compile(r'\{(\w*_)*date\}')
SIGNED_CURRENCY_FMT = '¤#,##0.###;¤-#,##0.###'
UNSIGNED_CURRENCY_FMT = '#,##0.### ¤¤'
def __init__(self, template_s, signed_currencies=frozenset(),
date_fmt=DATE_FMT,
signed_currency_fmt=SIGNED_CURRENCY_FMT,
unsigned_currency_fmt=UNSIGNED_CURRENCY_FMT,
template_name='<template>'):
self.date_fmt = date_fmt
self.splitter = AccountSplitter(
signed_currencies, signed_currency_fmt, unsigned_currency_fmt, template_name)
lines = self._template_lines(template_s)
self.format_funcs = []
try:
self.format_funcs.append(next(lines).format_map)
except StopIteration:
return
metadata = []
for line in lines:
if line.startswith(';'):
metadata.append(line)
else:
self._add_str_func(metadata)
metadata = []
line = line.strip()
match = self.ACCOUNT_SPLIT_RE.search(line)
if match is None:
raise errors.UserInputError("no amount expression found", line)
account = line[:match.start()]
amount_expr = line[match.end():]
self.splitter.add(account, amount_expr)
self.format_funcs.append(self.splitter.render_next)
self._add_str_func(metadata)
self.format_funcs.append('\n'.format_map)
def _nonblank_lines(self, s):
for line in s.splitlines(True):
line = line.strip()
if line:
yield line
def _template_lines(self, template_s):
lines = self._nonblank_lines(template_s)
try:
line1 = next(lines)
except StopIteration:
return
if self.PAYEE_LINE_RE.match(line1):
yield '\n' + line1
else:
yield '\n{date} {payee}'
yield line1
yield from lines
def _add_str_func(self, str_seq):
str_flat = ''.join('\n ' + s for s in str_seq)
if not str_flat:
pass
elif self.splitter.is_empty():
self.format_funcs.append(str_flat.format_map)
else:
self.splitter.set_metadata(str_flat)
def render(self, template_vars):
# template_vars must have these keys. Raise a KeyError if not.
template_vars['currency']
template_vars['payee']
try:
date = template_vars['date']
except KeyError:
date = datetime.date.today()
render_vars = {
'amount': strparse.currency_decimal(template_vars['amount']),
'date': date.strftime(self.date_fmt),
}
for key, value in template_vars.items():
if key.endswith('_date'):
render_vars[key] = value.strftime(self.date_fmt)
all_vars = collections.ChainMap(render_vars, template_vars)
return ''.join(f(all_vars) for f in self.format_funcs)
def is_empty(self):
return not self.format_funcs
from .. import errors
class LedgerEntryHook:
KIND = HOOK_KINDS.OUTPUT
@ -8,9 +275,29 @@ class LedgerEntryHook:
def __init__(self, config):
self.config = config
@staticmethod
@functools.lru_cache()
def _load_template(config, section_name, config_key):
section_config = config.get_section(section_name)
try:
template_s = section_config[config_key]
except KeyError:
raise errors.UserInputConfigurationError(
"template not defined in [{}]".format(section_name),
config_key,
)
return Template(
template_s,
date_fmt=section_config['date_format'],
signed_currencies=[code.strip().upper() for code in section_config['signed_currencies'].split(',')],
signed_currency_fmt=section_config['signed_currency_format'],
unsigned_currency_fmt=section_config['unsigned_currency_format'],
template_name=config_key,
)
def run(self, entry_data):
try:
template = self.config.get_template(entry_data['template'])
template = self._load_template(self.config, None, entry_data['template'])
except errors.UserInputConfigurationError as error:
if error.strerror.startswith('template not defined '):
have_template = False

View file

@ -1,268 +0,0 @@
import collections
import datetime
import decimal
import functools
import io
import operator
import re
import tokenize
import babel.numbers
from . import errors, strparse
class TokenTransformer:
def __init__(self, source):
try:
source = source.readline
except AttributeError:
pass
self.in_tokens = tokenize.tokenize(source)
@classmethod
def from_bytes(cls, b):
return cls(io.BytesIO(b).readline)
@classmethod
def from_str(cls, s, encoding='utf-8'):
return cls.from_bytes(s.encode(encoding))
def __iter__(self):
for ttype, tvalue, _, _, _ in self.in_tokens:
try:
transformer = getattr(self, 'transform_' + tokenize.tok_name[ttype])
except AttributeError:
raise ValueError("{} token {!r} not supported".format(ttype, tvalue))
yield from transformer(ttype, tvalue)
def _noop_transformer(self, ttype, tvalue):
yield (ttype, tvalue)
transform_ENDMARKER = _noop_transformer
def transform_ENCODING(self, ttype, tvalue):
self.in_encoding = tvalue
return self._noop_transformer(ttype, tvalue)
def transform(self):
out_bytes = tokenize.untokenize(self)
return out_bytes.decode(self.in_encoding)
class AmountTokenTransformer(TokenTransformer):
SUPPORTED_OPS = frozenset('+-*/()')
def __iter__(self):
tokens = super().__iter__()
for token in tokens:
yield token
if token[0] == tokenize.NAME:
break
else:
raise ValueError("no amount in expression")
yield from tokens
def transform_NUMBER(self, ttype, tvalue):
yield (tokenize.NAME, 'Decimal')
yield (tokenize.OP, '(')
yield (tokenize.STRING, repr(tvalue))
yield (tokenize.OP, ')')
def transform_OP(self, ttype, tvalue):
if tvalue == '{':
try:
name_type, name_value, _, _, _ = next(self.in_tokens)
close_type, close_value, _, _, _ = next(self.in_tokens)
if (name_type != tokenize.NAME
or name_value != name_value.lower()
or close_type != tokenize.OP
or close_value != '}'):
raise ValueError()
except (StopIteration, ValueError):
raise ValueError("opening { does not name variable")
yield (tokenize.NAME, name_value)
elif tvalue in self.SUPPORTED_OPS:
yield from self._noop_transformer(ttype, tvalue)
else:
raise ValueError("unsupported operator {!r}".format(tvalue))
class AccountSplitter:
TARGET_LINE_LEN = 78
# -4 because that's how many spaces prefix an account line.
TARGET_ACCTLINE_LEN = TARGET_LINE_LEN - 4
def __init__(self, signed_currencies, signed_currency_fmt, unsigned_currency_fmt,
template_name):
self.splits = []
self.metadata = []
self.signed_currency_fmt = signed_currency_fmt
self.unsigned_currency_fmt = unsigned_currency_fmt
self.signed_currencies = set(signed_currencies)
self.template_name = template_name
self._last_template_vars = object()
def is_empty(self):
return not self.splits
def add(self, account, amount_expr):
try:
clean_expr = AmountTokenTransformer.from_str(amount_expr).transform()
compiled_expr = compile(clean_expr, self.template_name, 'eval')
except (SyntaxError, tokenize.TokenError, ValueError) as error:
raise errors.UserInputConfigurationError(error.args[0], amount_expr)
else:
self.splits.append((account, compiled_expr))
self.metadata.append('')
def set_metadata(self, metadata_s):
self.metadata[-1] = metadata_s
def _currency_decimal(self, amount, currency):
return decimal.Decimal(babel.numbers.format_currency(amount, currency, '###0.###'))
def _balance_amounts(self, amounts, to_amount):
cmp_func = operator.lt if to_amount > 0 else operator.gt
should_balance = functools.partial(cmp_func, 0)
remainder = to_amount
balance_index = None
for index, (_, amount) in enumerate(amounts):
if should_balance(amount):
remainder -= amount
balance_index = index
if balance_index is None:
pass
elif (abs(remainder) / abs(to_amount)) >= decimal.Decimal('.1'):
raise errors.UserInputConfigurationError(
"template can't balance amounts to {}".format(to_amount),
self.template_name,
)
else:
account_name, start_amount = amounts[balance_index]
amounts[balance_index] = (account_name, start_amount + remainder)
def _build_amounts(self, template_vars):
amount_vars = {k: v for k, v in template_vars.items() if isinstance(v, decimal.Decimal)}
amount_vars['Decimal'] = decimal.Decimal
amounts = [
(account, self._currency_decimal(eval(amount_expr, amount_vars),
template_vars['currency']))
for account, amount_expr in self.splits
]
self._balance_amounts(amounts, template_vars['amount'])
self._balance_amounts(amounts, -template_vars['amount'])
return amounts
def _iter_splits(self, template_vars):
amounts = self._build_amounts(template_vars)
if template_vars['currency'] in self.signed_currencies:
amt_fmt = self.signed_currency_fmt
else:
amt_fmt = self.unsigned_currency_fmt
for (account, amount), metadata in zip(amounts, self.metadata):
if amount == 0:
yield ''
else:
account_s = account.format_map(template_vars)
amount_s = babel.numbers.format_currency(amount, template_vars['currency'], amt_fmt)
sep_len = max(2, self.TARGET_ACCTLINE_LEN - len(account_s) - len(amount_s))
yield '\n {}{}{}{}'.format(
account_s, ' ' * sep_len, amount_s,
metadata.format_map(template_vars),
)
def render_next(self, template_vars):
if template_vars is not self._last_template_vars:
self._split_iter = self._iter_splits(template_vars)
self._last_template_vars = template_vars
return next(self._split_iter)
class Template:
ACCOUNT_SPLIT_RE = re.compile(r'(?:\t| )\s*')
DATE_FMT = '%Y/%m/%d'
PAYEE_LINE_RE = re.compile(r'\{(\w*_)*date\}')
SIGNED_CURRENCY_FMT = '¤#,##0.###;¤-#,##0.###'
UNSIGNED_CURRENCY_FMT = '#,##0.### ¤¤'
def __init__(self, template_s, signed_currencies=frozenset(),
date_fmt=DATE_FMT,
signed_currency_fmt=SIGNED_CURRENCY_FMT,
unsigned_currency_fmt=UNSIGNED_CURRENCY_FMT,
template_name='<template>'):
self.date_fmt = date_fmt
self.splitter = AccountSplitter(
signed_currencies, signed_currency_fmt, unsigned_currency_fmt, template_name)
lines = self._template_lines(template_s)
self.format_funcs = []
try:
self.format_funcs.append(next(lines).format_map)
except StopIteration:
return
metadata = []
for line in lines:
if line.startswith(';'):
metadata.append(line)
else:
self._add_str_func(metadata)
metadata = []
line = line.strip()
match = self.ACCOUNT_SPLIT_RE.search(line)
if match is None:
raise errors.UserInputError("no amount expression found", line)
account = line[:match.start()]
amount_expr = line[match.end():]
self.splitter.add(account, amount_expr)
self.format_funcs.append(self.splitter.render_next)
self._add_str_func(metadata)
self.format_funcs.append('\n'.format_map)
def _nonblank_lines(self, s):
for line in s.splitlines(True):
line = line.strip()
if line:
yield line
def _template_lines(self, template_s):
lines = self._nonblank_lines(template_s)
try:
line1 = next(lines)
except StopIteration:
return
if self.PAYEE_LINE_RE.match(line1):
yield '\n' + line1
else:
yield '\n{date} {payee}'
yield line1
yield from lines
def _add_str_func(self, str_seq):
str_flat = ''.join('\n ' + s for s in str_seq)
if not str_flat:
pass
elif self.splitter.is_empty():
self.format_funcs.append(str_flat.format_map)
else:
self.splitter.set_metadata(str_flat)
def render(self, template_vars):
# template_vars must have these keys. Raise a KeyError if not.
template_vars['currency']
template_vars['payee']
try:
date = template_vars['date']
except KeyError:
date = datetime.date.today()
render_vars = {
'amount': strparse.currency_decimal(template_vars['amount']),
'date': date.strftime(self.date_fmt),
}
for key, value in template_vars.items():
if key.endswith('_date'):
render_vars[key] = value.strftime(self.date_fmt)
all_vars = collections.ChainMap(render_vars, template_vars)
return ''.join(f(all_vars) for f in self.format_funcs)
def is_empty(self):
return not self.format_funcs

View file

@ -1,3 +1,9 @@
[DEFAULT]
date_format = %%Y-%%m-%%d
signed_currencies = USD, CAD
signed_currency_format = ¤#,##0.###
unsigned_currency_format = #,##0.### ¤¤
[Simplest]
template = Accrued:Accounts Receivable {amount}
Income:Donations -{amount}

View file

@ -21,32 +21,6 @@ def config_from_file(path, arglist=[], stdout=None, stderr=None):
arglist = ['-C', path.as_posix(), *arglist, os.devnull]
return config.Configuration(arglist, stdout, stderr)
def test_defaults():
config = config_from_file('test_config.ini', ['--sign', 'GBP', '-O', 'out_arg'])
factory = mock.Mock(name='Template')
template = config.get_template('one', 'Templates', factory)
assert factory.called
kwargs = factory.call_args[1]
assert list(kwargs.pop('signed_currencies', '')) == ['GBP']
assert kwargs == {
'date_fmt': '%Y-%m-%d',
'signed_currency_fmt': kwargs['signed_currency_fmt'],
'template_name': 'one',
'unsigned_currency_fmt': kwargs['unsigned_currency_fmt'],
}
def test_template_parsing():
config = config_from_file('test_config.ini')
factory = mock.Mock(name='Template')
template = config.get_template('two', 'Templates', factory)
try:
tmpl_s = factory.call_args[0][0]
except IndexError as error:
assert False, error
assert "\n;Tag1: {value}\n" in tmpl_s
assert "\nIncome:Donations -{amount}\n" in tmpl_s
assert "\n;IncomeTag: Donations\n" in tmpl_s
def test_get_section():
config = config_from_file('test_config.ini', ['--date-format', '%m/%d/%Y'])
section = config.get_section('Templates')
@ -126,9 +100,3 @@ def test_bad_loglevel():
with bad_config('wraning'):
config = config_from_file('test_config.ini', ['-c', 'Bad Loglevel'])
config.get_loglevel()
def test_undefined_template():
template_name = 'template nonexistent'
config = config_from_file(os.devnull)
with bad_config(template_name):
config.get_template(template_name)

View file

@ -7,7 +7,7 @@ import io
import pathlib
import pytest
from import2ledger import errors, template
from import2ledger import errors
from import2ledger.hooks import ledger_entry
from . import DATA_DIR, normalize_whitespace
@ -19,7 +19,7 @@ with pathlib.Path(DATA_DIR, 'templates.ini').open() as conffile:
config.read_file(conffile)
def template_from(section_name, *args, **kwargs):
return template.Template(config[section_name]['template'], *args, **kwargs)
return ledger_entry.Template(config[section_name]['template'], *args, **kwargs)
def template_vars(payee, amount, currency='USD', date=DATE, other_vars=None):
call_vars = {
@ -27,6 +27,7 @@ def template_vars(payee, amount, currency='USD', date=DATE, other_vars=None):
'currency': currency,
'date': date,
'payee': payee,
'template': 'template',
}
if other_vars is None:
return call_vars
@ -62,7 +63,7 @@ def test_currency_formatting():
assert_easy_render(tmpl, 'CC', '7.99', 'USD', '2015/03/14', '$7.99')
def test_empty_template():
tmpl = template.Template("\n \n")
tmpl = ledger_entry.Template("\n \n")
assert tmpl.render(template_vars('BB', '8.99')) == ''
assert tmpl.is_empty()
@ -201,49 +202,43 @@ def test_line1_not_custom_payee():
])
def test_bad_amount_expression(amount_expr):
with pytest.raises(errors.UserInputError):
template.Template(" Income " + amount_expr)
ledger_entry.Template(" Income " + amount_expr)
class Config:
def __init__(self):
def __init__(self, use_section):
self.section_name = use_section
self.stdout = io.StringIO()
@contextlib.contextmanager
def open_output_file(self):
yield self.stdout
def get_template(self, key):
try:
return template_from(key)
except KeyError:
raise errors.UserInputConfigurationError(
"template not defined in test config", key)
def get_section(self, name=None):
return config[self.section_name]
def run_hook(entry_data):
hook_config = Config()
def run_hook(entry_data, config_section):
hook_config = Config(config_section)
hook = ledger_entry.LedgerEntryHook(hook_config)
assert hook.run(entry_data) is None
stdout = hook_config.stdout.getvalue()
return normalize_whitespace(stdout).splitlines()
def hook_vars(template_key, payee, amount):
return template_vars(payee, amount, other_vars={'template': template_key})
def test_hook_renders_template():
entry_data = hook_vars('Simplest', 'BB', '0.99')
lines = run_hook(entry_data)
entry_data = template_vars('BB', '0.99')
lines = run_hook(entry_data, 'Simplest')
assert lines == [
"",
"2015/03/14 BB",
" Accrued:Accounts Receivable 0.99 USD",
" Income:Donations -0.99 USD",
"2015-03-14 BB",
" Accrued:Accounts Receivable $0.99",
" Income:Donations -$0.99",
]
def test_hook_handles_empty_template():
entry_data = hook_vars('Empty', 'CC', 1)
assert not run_hook(entry_data)
entry_data = template_vars('CC', 1)
assert not run_hook(entry_data, 'Empty')
def test_hook_handles_template_undefined():
entry_data = hook_vars('Nonexistent', 'DD', 1)
assert not run_hook(entry_data)
entry_data = template_vars('DD', 1)
assert not run_hook(entry_data, 'Nonexistent')