conservancy_beancount/tests/testutil.py

332 lines
10 KiB
Python
Raw Normal View History

"""Mock Beancount objects for testing"""
# Copyright © 2020 Brett Smith
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
import datetime
import itertools
import beancount.core.amount as bc_amount
import beancount.core.data as bc_data
from decimal import Decimal
from pathlib import Path
2020-03-25 04:12:20 +00:00
from conservancy_beancount import rtutil
# Shared date and path fixtures for the test suite.

# December 30 of the largest year the datetime module can represent.
EXTREME_FUTURE_DATE = datetime.date(datetime.MAXYEAR, 12, 30)
# About 99 years (365 * 99 days) after the day this module is imported.
FUTURE_DATE = datetime.date.today() + datetime.timedelta(days=365 * 99)
# NOTE(review): the FY_ prefix presumably means "fiscal year" - confirm.
FY_START_DATE = datetime.date(2020, 3, 1)
# Default date used by date_seq(), Cost(), and Transaction() in this module.
FY_MID_DATE = datetime.date(2020, 9, 1)
PAST_DATE = datetime.date(2000, 1, 1)
# Directory containing this file; test_path() resolves relative paths
# against it.
TESTS_DIR = Path(__file__).parent
def check_post_meta(txn, *expected_meta, default=None):
    """Assert that every posting of ``txn`` carries the expected metadata.

    ``expected_meta`` supplies one entry per posting, in posting order.
    A falsy entry (``None``, ``{}``) asserts that the posting has no
    metadata at all.  Otherwise the entry is a dict, and only the keys it
    names are compared; a key missing from the posting's metadata is read
    as ``default``.

    Raises AssertionError on the first mismatch, or when the number of
    postings differs from the number of expectations.
    """
    postings = txn.postings
    assert len(postings) == len(expected_meta)
    for posting, wanted in zip(postings, expected_meta):
        if not wanted:
            assert not posting.meta
            continue
        if posting.meta is None:
            # No metadata at all can never equal a non-empty expectation.
            found = None
        else:
            found = {name: posting.meta.get(name, default) for name in wanted}
        assert found == wanted
2020-03-28 18:31:17 +00:00
def combine_values(*value_seqs):
    """Zip the given sequences together, cycling the shorter ones.

    The result is an iterator of tuples with one element per argument.
    It yields as many tuples as the longest argument that supports
    ``len()``; arguments without a length (such as plain iterators) are
    cycled too but do not influence how many tuples are produced.
    """
    lengths = [0]
    for values in value_seqs:
        try:
            lengths.append(len(values))
        except TypeError:
            # Unsized iterables do not contribute to the stopping point.
            continue
    cycles = [itertools.cycle(values) for values in value_seqs]
    return itertools.islice(zip(*cycles), max(lengths))
2020-04-22 13:17:58 +00:00
def date_seq(date=FY_MID_DATE, step=1):
    """Yield an endless series of dates, ``step`` days apart.

    The first value yielded is ``date`` itself; each later value is the
    previous one plus ``datetime.timedelta(days=step)``.
    """
    current = date
    for _ in itertools.repeat(None):
        yield current
        current = current + datetime.timedelta(days=step)
def parse_date(s, fmt='%Y-%m-%d'):
    """Parse the string ``s`` with strptime format ``fmt``; return a date."""
    parsed = datetime.datetime.strptime(s, fmt)
    return parsed.date()
def test_path(s):
    """Return ``s`` as a Path, anchoring relative paths at TESTS_DIR.

    ``None`` is passed through unchanged.  Absolute paths are returned
    as-is (converted to ``Path``).
    """
    if s is None:
        return None
    path = Path(s)
    return path if path.is_absolute() else TESTS_DIR / path
def Amount(number, currency='USD'):
    """Build a Beancount Amount, converting ``number`` to Decimal first."""
    quantity = Decimal(number)
    return bc_amount.Amount(quantity, currency)
2020-04-09 18:13:07 +00:00
def Cost(number, currency='USD', date=FY_MID_DATE, label=None):
    """Build a Beancount Cost, converting ``number`` to Decimal first."""
    per_unit = Decimal(number)
    return bc_data.Cost(per_unit, currency, date, label)
def Posting(account, number,
            currency='USD', cost=None, price=None, flag=None,
            type_=bc_data.Posting, **meta):
    """Build a posting object for tests.

    ``number`` and ``currency`` are combined through ``Amount()``.  When
    ``cost`` is given it must be a sequence of positional arguments for
    ``Cost()``.  Any extra keyword arguments become the posting's
    metadata dict; with none given, the metadata is ``None``.  ``type_``
    is the callable used to construct the result and receives
    ``(account, units, cost, price, flag, meta)``.
    """
    # Build the cost before the units, matching the order in which
    # conversion errors have always surfaced.
    cost_obj = None if cost is None else Cost(*cost)
    units = Amount(number, currency)
    return type_(account, units, cost_obj, price, flag, meta or None)
2020-03-29 02:19:49 +00:00
# Sample metadata values shared by tests.

# NOTE(review): judging by the name, these are strings that should be
# treated as links (a repository-relative path and two rt: references) -
# confirm against the code that consumes them.
LINK_METADATA_STRINGS = {
'Invoices/304321.pdf',
'rt:123/456',
'rt://ticket/234',
}
# Empty or whitespace-only strings.
# NOTE(review): the last two entries look identical in this view; being a
# set, duplicates collapse.  Confirm whether they were meant to differ
# (e.g. different amounts or kinds of whitespace).
NON_LINK_METADATA_STRINGS = {
'',
' ',
' ',
}
# Metadata values that are not str instances: a Decimal, a date, and two
# Amounts (the second built with a currency of None).
NON_STRING_METADATA_VALUES = [
Decimal(5),
FY_MID_DATE,
Amount(50),
Amount(500, None),
]
# Endless iterator over equity account names.  Transaction.opening_balance()
# takes the next one when no account is passed; because this is a single
# module-level iterator, its position carries over between calls.
OPENING_EQUITY_ACCOUNTS = itertools.cycle([
'Equity:Funds:Unrestricted',
'Equity:Funds:Restricted',
'Equity:OpeningBalance',
])
2020-04-12 15:00:41 +00:00
def balance_map(source=None, **kwargs):
    """Return a dict mapping currency name strings to Amount instances.

    ``source`` may be a mapping of currency name to number, or any
    iterable of ``(currency, number)`` pairs.  Keyword arguments are read
    the same way (``USD=5``) and are applied after ``source``, so they win
    when both name the same currency.  Each number can be anything
    ``Decimal`` accepts (a decimal string, an int, etc.).
    """
    from collections.abc import Mapping

    retval = {}
    if source is not None:
        # Iterating a mapping directly yields only its keys, so unpacking
        # (currency, number) pairs from it would fail or mis-parse.  Read
        # its items instead; other iterables are taken as pairs already.
        pairs = source.items() if isinstance(source, Mapping) else source
        retval.update(
            (currency, Amount(number, currency)) for currency, number in pairs
        )
    for currency, number in kwargs.items():
        retval[currency] = Amount(number, currency)
    return retval
class Transaction:
def __init__(self,
date=FY_MID_DATE, flag='*', payee=None,
narration='', tags=None, links=None, postings=None,
**meta):
if isinstance(date, str):
date = parse_date(date)
self.date = date
self.flag = flag
self.payee = payee
self.narration = narration
self.tags = set(tags or '')
self.links = set(links or '')
self.postings = []
self.meta = {
'filename': '<test>',
'lineno': 0,
}
self.meta.update(meta)
if postings is not None:
for posting in postings:
self.add_posting(*posting)
def add_posting(self, arg, *args, **kwargs):
"""Add a posting to this transaction. Use any of these forms:
txn.add_posting(account, number, , kwarg=value, )
txn.add_posting(account, number, , posting_kwargs_dict)
txn.add_posting(posting_object)
"""
if kwargs:
posting = Posting(arg, *args, **kwargs)
elif args:
if isinstance(args[-1], dict):
kwargs = args[-1]
args = args[:-1]
posting = Posting(arg, *args, **kwargs)
else:
posting = arg
self.postings.append(posting)
@classmethod
def opening_balance(cls, acct=None, **txn_meta):
if acct is None:
acct = next(OPENING_EQUITY_ACCOUNTS)
return cls(**txn_meta, postings=[
('Assets:Receivable:Accounts', 100),
('Assets:Receivable:Loans', 200),
('Liabilities:Payable:Accounts', -15),
('Liabilities:Payable:Vacation', -25),
(acct, -260),
])
class TestConfig:
    """Configuration double whose answers are fixed at construction."""

    def __init__(self, *, payment_threshold=0, repo_path=None, rt_client=None):
        """Store the canned values.

        ``payment_threshold`` is converted to Decimal.  ``repo_path`` is
        resolved through ``test_path()``.  When ``rt_client`` is given it
        is also wrapped in ``rtutil.RT``; otherwise the wrapper is None.
        """
        self._payment_threshold = Decimal(payment_threshold)
        self.repo_path = test_path(repo_path)
        self._rt_client = rt_client
        self._rt_wrapper = None if rt_client is None else rtutil.RT(rt_client)

    def payment_threshold(self):
        """Return the configured payment threshold as a Decimal."""
        return self._payment_threshold

    def repository_path(self):
        """Return the configured repository path (a Path, or None)."""
        return self.repo_path

    def rt_client(self):
        """Return the RT client given at construction, if any."""
        return self._rt_client

    def rt_wrapper(self):
        """Return the rtutil.RT wrapper around the client, or None."""
        return self._rt_wrapper
2020-03-24 13:08:08 +00:00
class _TicketBuilder:
    """Generate fake attachment listings for tickets.

    Each attachment is a tuple whose first element is a sequential id
    string, followed by the fields of one of the templates below
    (name, content type, and a size-like string).
    """

    # Attachments emitted at the start of every message.
    MESSAGE_ATTACHMENTS = [
        ('(Unnamed)', 'multipart/alternative', '0b'),
        ('(Unnamed)', 'text/plain', '1.2k'),
        ('(Unnamed)', 'text/html', '1.4k'),
    ]
    # Extra attachments, handed out round-robin after the message parts.
    MISC_ATTACHMENTS = [
        ('Forwarded Message.eml', 'message/rfc822', '3.1k'),
        ('photo.jpg', 'image/jpeg', '65.2k'),
        ('ConservancyInvoice-301.pdf', 'application/pdf', '326k'),
        ('Company_invoice-2020030405_as-sent.pdf', 'application/pdf', '50k'),
        ('statement.txt', 'text/plain', '652b'),
        ('screenshot.png', 'image/png', '1.9m'),
    ]

    def __init__(self):
        """Start attachment ids at 1 and the misc rotation at its head."""
        self.id_seq = itertools.count(1)
        self.misc_attchs = itertools.cycle(self.MISC_ATTACHMENTS)

    def new_attch(self, attch):
        """Return ``attch`` prefixed with the next id, as a string."""
        ident = str(next(self.id_seq))
        return (ident,) + tuple(attch)

    def new_msg_with_attachments(self, attachments_count=1):
        """Yield one message's parts, then ``attachments_count`` extras."""
        yield from map(self.new_attch, self.MESSAGE_ATTACHMENTS)
        for _ in range(attachments_count):
            yield self.new_attch(next(self.misc_attchs))

    def new_messages(self, messages_count, attachments_count=None):
        """Yield the attachments of ``messages_count`` messages in order.

        With ``attachments_count`` left as None, the first message gets
        ``messages_count`` extra attachments and each later message gets
        one fewer.  Otherwise every message gets ``attachments_count``.
        """
        for index in range(messages_count):
            if attachments_count is None:
                extras = messages_count - index
            else:
                extras = attachments_count
            yield from self.new_msg_with_attachments(extras)
2020-03-24 13:08:08 +00:00
class RTClient:
    """In-memory fake of an RT client, backed by canned ticket data.

    TICKET_DATA maps ticket id strings to lists of attachment tuples as
    produced by _TicketBuilder: ``(id, name, content type, size-like
    string)``.
    """

    # The builder is only needed while the class body runs; ids keep
    # counting up across tickets because one builder makes all of them.
    _builder = _TicketBuilder()
    DEFAULT_URL = 'https://example.org/defaultrt/REST/1.0/'
    TICKET_DATA = {
        '1': list(_builder.new_messages(1, 3)),
        '2': list(_builder.new_messages(2, 1)),
        '3': list(_builder.new_messages(3, 0)),
    }
    del _builder

    def __init__(self,
                 url=DEFAULT_URL,
                 default_login=None,
                 default_password=None,
                 proxy=None,
                 default_queue='General',
                 skip_login=False,
                 verify_cert=True,
                 http_auth=None,
    ):
        """Record the connection settings and the initial login state.

        ``proxy``, ``default_queue``, and ``verify_cert`` are accepted
        but not stored.  With ``http_auth``, credentials come from its
        ``username``/``password`` attributes and the client starts out
        logged in.  Without it, ``login_result`` starts as True when
        ``skip_login`` is truthy and as None otherwise.
        """
        self.url = url
        self.last_login = None
        if http_auth is not None:
            self.user = http_auth.username
            self.password = http_auth.password
            self.auth_method = type(http_auth).__name__
            self.login_result = True
            return
        self.user = default_login
        self.password = default_password
        self.auth_method = 'login'
        self.login_result = skip_login or None

    def login(self, login=None, password=None):
        """Simulate a login attempt and return whether it succeeded.

        When both arguments are None, the credentials given at
        construction are used.  The attempt succeeds when both values are
        truthy and the password does not start with ``'bad'``.  The
        outcome is stored in ``login_result`` and ``last_login``.
        """
        if login is None and password is None:
            login, password = self.user, self.password
        succeeded = bool(
            login and password and not password.startswith('bad')
        )
        self.login_result = succeeded
        self.last_login = (login, password, succeeded)
        return succeeded

    def get_attachments(self, ticket_id):
        """Return a copy of the ticket's attachment list, or None."""
        try:
            attachments = self.TICKET_DATA[str(ticket_id)]
        except KeyError:
            return None
        else:
            return list(attachments)

    def get_attachment(self, ticket_id, attachment_id):
        """Return a dict describing one attachment, or None if not found.

        ``Transaction`` is the id of the last ``multipart/`` attachment
        listed before the requested one in the ticket, falling back to
        the attachment's own id when there is none.  A name of
        ``'(Unnamed)'`` is reported as an empty ``Filename``.
        """
        try:
            attachments = self.TICKET_DATA[str(ticket_id)]
        except KeyError:
            return None
        att_id = str(attachment_id)
        multipart_id = None
        found = None
        for entry in attachments:
            if entry[0] == att_id:
                found = entry
                break
            if entry[2].startswith('multipart/'):
                multipart_id = entry[0]
        if found is None:
            return None
        name = found[1]
        return {
            'id': att_id,
            'ContentType': found[2],
            'Filename': '' if name == '(Unnamed)' else name,
            'Transaction': multipart_id or att_id,
        }

    def get_ticket(self, ticket_id):
        """Return a minimal ticket dict, or None for an unknown ticket."""
        ticket_id_s = str(ticket_id)
        if ticket_id_s in self.TICKET_DATA:
            return {
                'id': 'ticket/{}'.format(ticket_id_s),
                'numerical_id': ticket_id_s,
            }
        return None