diff --git a/import2ledger/template.py b/import2ledger/template.py index 160e3fe..6e8ac63 100644 --- a/import2ledger/template.py +++ b/import2ledger/template.py @@ -52,7 +52,15 @@ class TokenTransformer: class AmountTokenTransformer(TokenTransformer): SUPPORTED_OPS = frozenset('+-*/()') - transform_NAME = TokenTransformer._noop_transformer + def __iter__(self): + tokens = super().__iter__() + for token in tokens: + yield token + if token[0] == tokenize.NAME: + break + else: + raise ValueError("no amount in expression") + yield from tokens def transform_NUMBER(self, ttype, tvalue): yield (tokenize.NAME, 'Decimal') @@ -66,6 +74,7 @@ class AmountTokenTransformer(TokenTransformer): name_type, name_value, _, _, _ = next(self.in_tokens) close_type, close_value, _, _, _ = next(self.in_tokens) if (name_type != tokenize.NAME + or name_value != name_value.lower() or close_type != tokenize.OP or close_value != '}'): raise ValueError() @@ -91,9 +100,11 @@ class AccountSplitter: def add(self, account, amount_expr): try: clean_expr = AmountTokenTransformer.from_str(amount_expr).transform() - except ValueError as error: + compiled_expr = compile(clean_expr, self.template_name, 'eval') + except (SyntaxError, tokenize.TokenError, ValueError) as error: raise errors.UserInputConfigurationError(error.args[0], amount_expr) - self.splits[account] = compile(clean_expr, self.template_name, 'eval') + else: + self.splits[account] = compiled_expr def _currency_decimal(self, amount, currency): return decimal.Decimal(babel.numbers.format_currency(amount, currency, '###0.###')) diff --git a/tests/test_templates.py b/tests/test_templates.py index 32c93ea..e7dd551 100644 --- a/tests/test_templates.py +++ b/tests/test_templates.py @@ -4,7 +4,7 @@ import decimal import pathlib import pytest -from import2ledger import template +from import2ledger import errors, template from . import DATA_DIR, normalize_whitespace @@ -87,3 +87,25 @@ def test_multivalue(): " Income:RBI -15.00 USD", " Income:Donations -135.00 USD", ] + +@pytest.mark.parametrize('amount_expr', [ + '', + 'name', + '-', + '()', + '+()', + '{}', + '{{}}', + '{()}', + '{name', + 'name}', + '{42}', + '(5).real', + '{amount.real}', + '{amount.is_nan()}', + '{Decimal}', + '{FOO}', +]) +def test_bad_amount_expression(amount_expr): + with pytest.raises(errors.UserInputError): + template.Template(" Income " + amount_expr)