diff --git a/conservancy_beancount/reports/rewrite.py b/conservancy_beancount/reports/rewrite.py
new file mode 100644
index 0000000..40cfda5
--- /dev/null
+++ b/conservancy_beancount/reports/rewrite.py
@@ -0,0 +1,375 @@
+"""rewrite.py - Post rewriting for financial reports"""
+# Copyright © 2020 Brett Smith
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+import abc
+import datetime
+import decimal
+import enum
+import logging
+import operator as opmod
+import re
+
+from typing import (
+ Callable,
+ Dict,
+ Generic,
+ IO,
+ Iterable,
+ Iterator,
+ List,
+ Mapping,
+ Optional,
+ Pattern,
+ Sequence,
+ Set,
+ Tuple,
+ Type,
+ TypeVar,
+ Union,
+)
+from ..beancount_types import (
+ Meta,
+ MetaKey,
+ MetaValue,
+)
+
+from pathlib import Path
+
+import yaml
+
+from .. import data
+
+Decimal = decimal.Decimal
+T = TypeVar('T')
+TestCallable = Callable[[T, T], bool]
+
+CMP_OPS: Mapping[str, TestCallable] = {
+ '==': opmod.eq,
+ '>=': opmod.ge,
+ '>': opmod.gt,
+ '<=': opmod.le,
+ '<': opmod.lt,
+ '!=': opmod.ne,
+}
+
+# First half of this regexp is pseudo-attribute access.
+# Second half is metadata keys, per the Beancount syntax docs.
+SUBJECT_PAT = r'((?:\.\w+)+|[a-z][-\w]*)\b\s*'
+
+logger = logging.getLogger('conservancy_beancount.reports.rewrite')
+
+class _Registry(Generic[T]):
+ def __init__(self,
+ description: str,
+ parser: Union[str, Pattern],
+ default: Type[T],
+ *others: Tuple[str, Type[T]],
+ ) -> None:
+ if isinstance(parser, str):
+ parser = re.compile(parser)
+ self.description = description
+ self.parser = parser
+ self.default = default
+ self.registry: Mapping[str, Type[T]] = dict(others)
+
+ def parse(self, s: str) -> T:
+ match = self.parser.match(s)
+ if match is None:
+ raise ValueError(f"could not parse {self.description} {s!r}")
+ subject = match.group(1)
+ operator = match.group(2)
+ operand = s[match.end():].strip()
+ if not subject.startswith('.'):
+ # FIXME: To avoid this type ignore, I would have to define a common
+ # superclass for Tester and Setter that provides a useful signature
+ # for __init__, including the versions that deal with Metadata,
+ # and then use that as the bound for our type variable.
+ # Not a priority right now.
+ return self.default(subject, operator, operand) # type:ignore[call-arg]
+ try:
+ retclass = self.registry[subject]
+ except KeyError:
+ raise ValueError(f"unknown subject in {self.description} {subject!r}") from None
+ else:
+ return retclass(operator, operand) # type:ignore[call-arg]
+
+
+class Tester(Generic[T], metaclass=abc.ABCMeta):
+ OPS: Mapping[str, TestCallable] = CMP_OPS
+
+ def __init__(self, operator: str, operand: str) -> None:
+ try:
+ self.op_func = self.OPS[operator]
+ except KeyError:
+ raise ValueError(f"unsupported operator {operator!r}") from None
+ self.operand = self.parse_operand(operand)
+
+ @staticmethod
+ @abc.abstractmethod
+ def parse_operand(operand: str) -> T: ...
+
+ @abc.abstractmethod
+ def post_get(self, post: data.Posting) -> T: ...
+
+ def __call__(self, post: data.Posting) -> bool:
+ return self.op_func(self.post_get(post), self.operand)
+
+
+class AccountTest(Tester[str]):
+ def __init__(self, operator: str, operand: str) -> None:
+ if operator == 'in':
+ self.under_args = operand.split()
+ for name in self.under_args:
+ self.parse_operand(name)
+ else:
+ super().__init__(operator, operand)
+
+ @staticmethod
+ def parse_operand(operand: str) -> str:
+ if data.Account.is_account(f'{operand}:RootsOK'):
+ return operand
+ else:
+ raise ValueError(f"invalid account name {operand!r}")
+
+ def post_get(self, post: data.Posting) -> str:
+ return post.account
+
+ def __call__(self, post: data.Posting) -> bool:
+ try:
+ return post.account.is_under(*self.under_args) is not None
+ except AttributeError:
+ return super().__call__(post)
+
+
+class DateTest(Tester[datetime.date]):
+ @staticmethod
+ def parse_operand(operand: str) -> datetime.date:
+ return datetime.datetime.strptime(operand, '%Y-%m-%d').date()
+
+ def post_get(self, post: data.Posting) -> datetime.date:
+ return post.meta.date
+
+
+class MetadataTest(Tester[Optional[MetaValue]]):
+ def __init__(self, key: MetaKey, operator: str, operand: str) -> None:
+ super().__init__(operator, operand)
+ self.key = key
+
+ @staticmethod
+ def parse_operand(operand: str) -> str:
+ return operand
+
+ def post_get(self, post: data.Posting) -> Optional[MetaValue]:
+ return post.meta.get(self.key)
+
+
+class NumberTest(Tester[Decimal]):
+ @staticmethod
+ def parse_operand(operand: str) -> Decimal:
+ try:
+ return Decimal(operand)
+ except decimal.DecimalException:
+ raise ValueError(f"could not parse decimal {operand!r}")
+
+ def post_get(self, post: data.Posting) -> Decimal:
+ return post.units.number
+
+
+TestRegistry: _Registry[Tester] = _Registry(
+ 'condition',
+ '^{}{}'.format(
+ SUBJECT_PAT,
+ r'({}|in)'.format('|'.join(re.escape(s) for s in Tester.OPS)),
+ ),
+ MetadataTest,
+ ('.account', AccountTest),
+ ('.date', DateTest),
+ ('.number', NumberTest),
+)
+
+class Setter(Generic[T], metaclass=abc.ABCMeta):
+ _regparser = re.compile(r'^{}{}'.format(
+ SUBJECT_PAT,
+ r'',
+ ))
+ _regtype = 'setter'
+
+ @abc.abstractmethod
+ def __call__(self, post: data.Posting) -> Tuple[str, T]: ...
+
+
+class AccountSet(Setter[data.Account]):
+ def __init__(self, operator: str, value: str) -> None:
+ if operator != '=':
+ raise ValueError(f"unsupported operator for account {operator!r}")
+ self.value = data.Account(AccountTest.parse_operand(value))
+
+ def __call__(self, post: data.Posting) -> Tuple[str, data.Account]:
+ return ('account', self.value)
+
+
+class MetadataSet(Setter[str]):
+ def __init__(self, key: str, operator: str, value: str) -> None:
+ if operator != '=':
+ raise ValueError(f"unsupported operator for metadata {operator!r}")
+ self.key = key
+ self.value = value
+
+ def __call__(self, post: data.Posting) -> Tuple[str, str]:
+ return (self.key, self.value)
+
+
+class NumberSet(Setter[data.Amount]):
+ def __init__(self, operator: str, value: str) -> None:
+ if operator != '*=':
+ raise ValueError(f"unsupported operator for number {operator!r}")
+ self.value = NumberTest.parse_operand(value)
+
+ def __call__(self, post: data.Posting) -> Tuple[str, data.Amount]:
+ number = post.units.number * self.value
+ return ('units', post.units._replace(number=number))
+
+
+SetRegistry: _Registry[Setter] = _Registry(
+ 'action',
+ rf'^{SUBJECT_PAT}([-+/*]?=)',
+ MetadataSet,
+ ('.account', AccountSet),
+ ('.number', NumberSet),
+)
+
+class _RootAccount(enum.Enum):
+ Assets = 'Assets'
+ Liabilities = 'Liabilities'
+ Equity = 'Equity'
+
+ @classmethod
+ def from_account(cls, name: str) -> '_RootAccount':
+ root, _, _ = name.partition(':')
+ try:
+ return cls[root]
+ except KeyError:
+ return cls.Equity
+
+
+class RewriteRule:
+ def __init__(self, source: Mapping[str, List[str]]) -> None:
+ self.new_meta: List[Sequence[MetadataSet]] = []
+ self.rewrites: List[Sequence[Setter]] = []
+ for key, rules in source.items():
+ if key == 'if':
+ self.tests = [TestRegistry.parse(rule) for rule in rules]
+ else:
+ new_meta: List[MetadataSet] = []
+ rewrites: List[Setter] = []
+ for rule_s in rules:
+ setter = SetRegistry.parse(rule_s)
+ if isinstance(setter, MetadataSet):
+ new_meta.append(setter)
+ elif any(isinstance(t, type(setter)) for t in rewrites):
+ raise ValueError(f"rule conflicts with earlier action: {rule_s!r}")
+ else:
+ rewrites.append(setter)
+ self.new_meta.append(new_meta)
+ self.rewrites.append(rewrites)
+
+ try:
+ if_ok = any(self.tests)
+ except AttributeError:
+ if_ok = False
+ if not if_ok:
+ raise ValueError("no `if` condition in rule") from None
+
+ account_conditions: Set[_RootAccount] = set()
+ for test in self.tests:
+ if isinstance(test, AccountTest):
+ try:
+ operands = test.under_args
+ except AttributeError:
+ operands = [test.operand]
+ account_conditions.update(_RootAccount.from_account(s) for s in operands)
+ if len(account_conditions) == 1:
+ account_condition: Optional[_RootAccount] = account_conditions.pop()
+ else:
+ account_condition = None
+
+ number_reallocation = Decimal()
+ for rewrite in self.rewrites:
+ rewrite_number = Decimal(1)
+ for rule in rewrite:
+ if isinstance(rule, AccountSet):
+ new_root = _RootAccount.from_account(rule.value)
+ if new_root is not account_condition:
+ raise ValueError(
+ f"cannot assign {new_root} account "
+ f"when `if` checks for {account_condition}",
+ )
+ elif isinstance(rule, NumberSet):
+ rewrite_number = rule.value
+ number_reallocation += rewrite_number
+
+ if not number_reallocation:
+ raise ValueError("no rewrite actions in rule")
+ elif number_reallocation != 1:
+ raise ValueError(f"rule multiplies number by {number_reallocation}")
+
+ def match(self, post: data.Posting) -> bool:
+ return all(test(post) for test in self.tests)
+
+ def rewrite(self, post: data.Posting) -> Iterator[data.Posting]:
+ for rewrite, new_meta in zip(self.rewrites, self.new_meta):
+ kwargs = dict(setter(post) for setter in rewrite)
+ if new_meta:
+ meta = post.meta.detached()
+ meta.update(meta_setter(post) for meta_setter in new_meta)
+ kwargs['meta'] = meta
+ yield post._replace(**kwargs)
+
+
+class RewriteRuleset:
+ def __init__(self, rules: Iterable[RewriteRule]) -> None:
+ self.rules = list(rules)
+
+ def rewrite(self, posts: Iterable[data.Posting]) -> Iterator[data.Posting]:
+ for post in posts:
+ for rule in self.rules:
+ if rule.match(post):
+ yield from rule.rewrite(post)
+ break
+ else:
+ yield post
+
+ @classmethod
+ def from_yaml(cls, source: Union[str, IO, Path]) -> 'RewriteRuleset':
+ if isinstance(source, Path):
+ with source.open() as source_file:
+ return cls.from_yaml(source_file)
+ doc = yaml.safe_load(source)
+ if not isinstance(doc, list):
+ raise ValueError("YAML root element is not a list")
+ for number, item in enumerate(doc, 1):
+ if not isinstance(item, Mapping):
+ raise ValueError(f"YAML item {number} is not a rule hash")
+ for key, value in item.items():
+ if not isinstance(value, list):
+ raise ValueError(f"YAML item {number} {key!r} value is not a list")
+ elif not all(isinstance(s, str) for s in value):
+ raise ValueError(f"YAML item {number} {key!r} value is not all strings")
+ try:
+ logger.debug("loaded %s rewrite rules from YAML", number)
+ except NameError:
+ logger.warning("YAML source is empty; no rewrite rules loaded")
+ return cls(RewriteRule(src) for src in doc)
diff --git a/tests/test_reports_rewrite.py b/tests/test_reports_rewrite.py
new file mode 100644
index 0000000..b7a03b0
--- /dev/null
+++ b/tests/test_reports_rewrite.py
@@ -0,0 +1,348 @@
+"""test_reports_rewrite - Unit tests for report rewrite functionality"""
+# Copyright © 2020 Brett Smith
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+import datetime
+
+import pytest
+
+from decimal import Decimal
+
+import yaml
+
+from . import testutil
+
+from conservancy_beancount import data
+from conservancy_beancount.reports import rewrite
+
+CMP_OPS = frozenset('< <= == != >= >'.split())
+
+@pytest.mark.parametrize('name', ['Equity:Other', 'Expenses:Other', 'Income:Other'])
+@pytest.mark.parametrize('operator', CMP_OPS)
+def test_account_condition(name, operator):
+ operand = 'Expenses:Other'
+ txn = testutil.Transaction(postings=[(name, -5)])
+ post, = data.Posting.from_txn(txn)
+ tester = rewrite.AccountTest(operator, operand)
+ assert tester(post) == eval(f'name {operator} operand')
+
+@pytest.mark.parametrize('name,expected', [
+ ('Expenses:Postage', True),
+ ('Expenses:Tax', True),
+ ('Expenses:Tax:Sales', True),
+ ('Expenses:Tax:VAT', True),
+ ('Expenses:Taxes', False),
+ ('Expenses:Other', False),
+ ('Liabilities:Tax', False),
+])
+def test_account_in_condition(name, expected):
+ txn = testutil.Transaction(postings=[(name, 5)])
+ post, = data.Posting.from_txn(txn)
+ tester = rewrite.AccountTest('in', 'Expenses:Tax Expenses:Postage')
+ assert tester(post) == expected
+
+@pytest.mark.parametrize('n', range(3, 12, 3))
+@pytest.mark.parametrize('operator', CMP_OPS)
+def test_date_condition(n, operator):
+ date = datetime.date(2020, n, n)
+ txn = testutil.Transaction(date=date, postings=[
+ ('Income:Other', -5),
+ ])
+ post, = data.Posting.from_txn(txn)
+ tester = rewrite.DateTest(operator, '2020-06-06')
+ assert tester(post) == eval(f'n {operator} 6')
+
+@pytest.mark.parametrize('value', ['test', 'testvalue', 'testzed'])
+@pytest.mark.parametrize('operator', CMP_OPS)
+def test_metadata_condition(value, operator):
+ key = 'testkey'
+ operand = 'testvalue'
+ txn = testutil.Transaction(postings=[
+ ('Income:Other', -5, {key: value}),
+ ])
+ post, = data.Posting.from_txn(txn)
+ tester = rewrite.MetadataTest(key, operator, operand)
+ assert tester(post) == eval(f'value {operator} operand')
+
+@pytest.mark.parametrize('value', ['4.5', '4.75', '5'])
+@pytest.mark.parametrize('operator', CMP_OPS)
+def test_number_condition(value, operator):
+ operand = '4.75'
+ txn = testutil.Transaction(postings=[
+ ('Expenses:Other', value),
+ ])
+ post, = data.Posting.from_txn(txn)
+ tester = rewrite.NumberTest(operator, operand)
+ assert tester(post) == eval(f'value {operator} operand')
+
+@pytest.mark.parametrize('subject,operand', [
+ ('.account', 'Income:Other'),
+ ('.date', '1990-05-10'),
+ ('.number', '5.79'),
+ ('testkey', 'testvalue'),
+])
+@pytest.mark.parametrize('operator', CMP_OPS)
+def test_parse_good_condition(subject, operator, operand):
+ actual = rewrite.TestRegistry.parse(f'{subject}{operator}{operand}')
+ if subject == '.account':
+ assert isinstance(actual, rewrite.AccountTest)
+ assert actual.operand == operand
+ elif subject == '.date':
+ assert isinstance(actual, rewrite.DateTest)
+ assert actual.operand == datetime.date(1990, 5, 10)
+ elif subject == '.number':
+ assert isinstance(actual, rewrite.NumberTest)
+ assert actual.operand == Decimal(operand)
+ else:
+ assert isinstance(actual, rewrite.MetadataTest)
+ assert actual.key == 'testkey'
+ assert actual.operand == 'testvalue'
+
+@pytest.mark.parametrize('cond_s', [
+ '.account = Income:Other', # Bad operator
+ '.account===Equity:Other', # Bad operand (`=Equity:Other` is not valid)
+ '.account in foo', # Bad operand
+ '.date == 1990-90-5', # Bad operand
+ '.date in 1990-9-9', # Bad operator
+ '.number > 0xff', # Bad operand
+ '.number in 16', # Bad operator
+ 'testkey in foo', # Bad operator
+ 'units.number == 5', # Bad subject (syntax)
+ '.units == 5', # Bad subject (unknown)
+])
+def test_parse_bad_condition(cond_s):
+ with pytest.raises(ValueError):
+ rewrite.TestRegistry.parse(cond_s)
+
+@pytest.mark.parametrize('value', ['Equity:Other', 'Income:Other'])
+def test_account_set(value):
+ value = data.Account(value)
+ txn = testutil.Transaction(postings=[
+ ('Expenses:Other', 5),
+ ])
+ post, = data.Posting.from_txn(txn)
+ setter = rewrite.AccountSet('=', value)
+ assert setter(post) == ('account', value)
+
+@pytest.mark.parametrize('key', ['aa', 'bb'])
+def test_metadata_set(key):
+ txn_meta = {'filename': 'metadata_set', 'lineno': 100}
+ post_meta = {'aa': 'one', 'bb': 'two'}
+ meta = {'aa': 'one', 'bb': 'two'}
+ txn = testutil.Transaction(**txn_meta, postings=[
+ ('Expenses:Other', 5, post_meta),
+ ])
+ post, = data.Posting.from_txn(txn)
+ setter = rewrite.MetadataSet(key, '=', 'newvalue')
+ assert setter(post) == (key, 'newvalue')
+
+@pytest.mark.parametrize('value', ['0.25', '-.5', '1.9'])
+@pytest.mark.parametrize('currency', ['USD', 'EUR', 'INR'])
+def test_number_set(value, currency):
+ txn = testutil.Transaction(postings=[
+ ('Expenses:Other', 5, currency),
+ ])
+ post, = data.Posting.from_txn(txn)
+ setter = rewrite.NumberSet('*=', value)
+ assert setter(post) == ('units', testutil.Amount(Decimal(value) * 5, currency))
+
+@pytest.mark.parametrize('subject,operator,operand', [
+ ('.account', '=', 'Income:Other'),
+ ('.number', '*=', '.5'),
+ ('.number', '*=', '-1'),
+ ('testkey', '=', 'testvalue'),
+])
+def test_parse_good_set(subject, operator, operand):
+ actual = rewrite.SetRegistry.parse(f'{subject}{operator}{operand}')
+ if subject == '.account':
+ assert isinstance(actual, rewrite.AccountSet)
+ assert actual.value == operand
+ elif subject == '.number':
+ assert isinstance(actual, rewrite.NumberSet)
+ assert actual.value == Decimal(operand)
+ else:
+ assert isinstance(actual, rewrite.MetadataSet)
+ assert actual.key == subject
+ assert actual.value == operand
+
+@pytest.mark.parametrize('set_s', [
+ '.account==Equity:Other', # Bad operand (`=Equity:Other` is not valid)
+ '.account*=2', # Bad operator
+ '.date = 2020-02-20', # Bad subject
+ '.number*=0xff', # Bad operand
+ '.number=5', # Bad operator
+ 'testkey += foo', # Bad operator
+ 'testkey *= 3', # Bad operator
+])
+def test_parse_bad_set(set_s):
+ with pytest.raises(ValueError):
+ rewrite.SetRegistry.parse(set_s)
+
+def test_good_rewrite_rule():
+ rule = rewrite.RewriteRule({
+ 'if': ['.account in Income'],
+ 'add': ['income-type = Other'],
+ })
+ txn = testutil.Transaction(postings=[
+ ('Assets:Cash', 10),
+ ('Income:Other', -10),
+ ])
+ cash_post, inc_post = data.Posting.from_txn(txn)
+ assert not rule.match(cash_post)
+ assert rule.match(inc_post)
+ new_post, = rule.rewrite(inc_post)
+ assert new_post.account == 'Income:Other'
+ assert new_post.units == testutil.Amount(-10)
+ assert new_post.meta.pop('income-type', None) == 'Other'
+ assert new_post.meta
+ assert new_post.meta.date == txn.date
+
+def test_complicated_rewrite_rule():
+ account = 'Income:Donations'
+ income_key = 'income-type'
+ income_type = 'Refund'
+ rule = rewrite.RewriteRule({
+ 'if': ['.account == Expenses:Refunds'],
+ 'project': [
+ f'.account = {account}',
+ '.number *= .8',
+ f'{income_key} = {income_type}',
+ ],
+ 'general': [
+ f'.account = {account}',
+ '.number *= .2',
+ f'{income_key} = {income_type}',
+ 'project = Conservancy',
+ ],
+ })
+ txn = testutil.Transaction(postings=[
+ ('Assets:Cash', -12),
+ ('Expenses:Refunds', 12, {'project': 'Bravo'}),
+ ])
+ cash_post, src_post = data.Posting.from_txn(txn)
+ assert not rule.match(cash_post)
+ assert rule.match(src_post)
+ proj_post, gen_post = rule.rewrite(src_post)
+ assert proj_post.account == 'Income:Donations'
+ assert proj_post.units == testutil.Amount('9.60')
+ assert proj_post.meta[income_key] == income_type
+ assert proj_post.meta['project'] == 'Bravo'
+ assert gen_post.account == 'Income:Donations'
+ assert gen_post.units == testutil.Amount('2.40')
+ assert gen_post.meta[income_key] == income_type
+ assert gen_post.meta['project'] == 'Conservancy'
+
+@pytest.mark.parametrize('source', [
+ # Account assignments
+ {'if': ['.account in Income Expenses'], 'then': ['.account = Equity']},
+ {'if': ['.account == Assets:PettyCash'], 'then': ['.account = Assets:Cash']},
+ {'if': ['.account == Liabilities:CreditCard'], 'then': ['.account = Liabilities:Visa']},
+ # Number splits
+ {'if': ['.date >= 2020-01-01'], 'a': ['.number *= 2'], 'b': ['.number *= -1']},
+ {'if': ['.date >= 2020-01-02'], 'a': ['.number *= .85'], 'b': ['.number *= .15']},
+ # Metadata assignment
+ {'if': ['a==1'], 'then': ['b=2', 'c=3']},
+])
+def test_valid_rewrite_rule(source):
+ assert rewrite.RewriteRule(source)
+
+@pytest.mark.parametrize('source', [
+ # Incomplete rules
+ {},
+ {'if': ['.account in Equity']},
+ {'a': ['.account = Income:Other'], 'b': ['.account = Expenses:Other']},
+ # Condition/assignment mixup
+ {'if': ['.account = Equity:Other'], 'then': ['equity-type = other']},
+ {'if': ['.account == Equity:Other'], 'then': ['equity-type != other']},
+ # Cross-category account assignment
+ {'if': ['.date >= 2020-01-01'], 'then': ['.account = Assets:Cash']},
+ {'if': ['.account in Equity'], 'then': ['.account = Assets:Cash']},
+ # Number reallocation != 1
+ {'if': ['.date >= 2020-01-01'], 'then': ['.number *= .5']},
+ {'if': ['.date >= 2020-01-01'], 'a': ['k1=v1'], 'b': ['k2=v2']},
+ # Date assignment
+ {'if': ['.date == 2020-01-01'], 'then': ['.date = 2020-02-02']},
+ # Redundant assignments
+ {'if': ['.account in Income'],
+ 'then': ['.account = Income:Other', '.account = Income:Other']},
+ {'if': ['.number > 0'],
+ 'a': ['.number *= .5', '.number *= .5'],
+ 'b': ['.number *= .5']},
+])
+def test_invalid_rewrite_rule(source):
+ with pytest.raises(ValueError):
+ rewrite.RewriteRule(source)
+
+def test_rewrite_ruleset():
+ account = 'Income:CurrencyConversion'
+ ruleset = rewrite.RewriteRuleset(rewrite.RewriteRule(src) for src in [
+ {'if': ['.account == Expenses:CurrencyConversion'],
+ 'rename': [f'.account = {account}']},
+ {'if': ['project == alpha', '.account != Assets:Cash'],
+ 'cap': ['project = Alpha']},
+ ])
+ txn = testutil.Transaction(project='alpha', postings=[
+ ('Assets:Cash', -20),
+ ('Expenses:CurrencyConversion', 18),
+ ('Expenses:CurrencyConversion', 1, {'project': 'Conservancy'}),
+ ('Expenses:BankingFees', 1),
+ ])
+ posts = ruleset.rewrite(data.Posting.from_txn(txn))
+ post = next(posts)
+ assert post.account == 'Assets:Cash'
+ assert post.meta['project'] == 'alpha'
+ post = next(posts)
+ assert post.account == account
+ # Project not capitalized because the first rule took priority
+ assert post.meta['project'] == 'alpha'
+ post = next(posts)
+ assert post.account == account
+ assert post.meta['project'] == 'Conservancy'
+ post = next(posts)
+ assert post.account == 'Expenses:BankingFees'
+ assert post.meta['project'] == 'Alpha'
+
+def test_ruleset_from_yaml_path():
+ yaml_path = testutil.test_path('userconfig/Rewrites01.yml')
+ assert rewrite.RewriteRuleset.from_yaml(yaml_path)
+
+def test_ruleset_from_yaml_str():
+ with testutil.test_path('userconfig/Rewrites01.yml').open() as yaml_file:
+ yaml_s = yaml_file.read()
+ assert rewrite.RewriteRuleset.from_yaml(yaml_s)
+
+def test_bad_ruleset_yaml_path():
+ yaml_path = testutil.test_path('repository/Projects/project-data.yml')
+ with pytest.raises(ValueError):
+ rewrite.RewriteRuleset.from_yaml(yaml_path)
+
+@pytest.mark.parametrize('source', [
+ # Wrong root objects
+ 1,
+ 2.3,
+ True,
+ None,
+ {},
+ 'string',
+ [{}, 'a'],
+ [{}, ['b']],
+ # Rules have wrong type
+ [{'if': '.account in Equity', 'add': ['testkey = value']}],
+ [{'if': ['.account in Equity'], 'add': 'testkey = value'}],
+])
+def test_bad_ruleset_yaml_str(source):
+ yaml_doc = yaml.safe_dump(source)
+ with pytest.raises(ValueError):
+ rewrite.RewriteRuleset.from_yaml(yaml_doc)
diff --git a/tests/userconfig/Rewrites01.yml b/tests/userconfig/Rewrites01.yml
new file mode 100644
index 0000000..0ea8ddd
--- /dev/null
+++ b/tests/userconfig/Rewrites01.yml
@@ -0,0 +1,4 @@
+- if:
+ - .account == Expenses:CurrencyConversion
+ then:
+ - .account = Income:CurrencyConversion