tests: Support more Decimal values from importers without specifying them all.

This commit is contained in:
Brett Smith 2017-12-17 13:42:51 -05:00
parent e3ec03cf19
commit a924d9fb1f
2 changed files with 25 additions and 14 deletions

View file

@ -3,11 +3,11 @@
expect:
- payee: Alex Jones
date: [2017, 9, 1]
amount: "150"
amount: "150.00"
currency: USD
- payee: Dakota Doe
date: [2017, 9, 1]
amount: "12"
amount: "12.00"
currency: USD
- source: PatreonEarnings.csv
@ -39,7 +39,7 @@
expect:
- payee: Patreon
date: [2017, 9, 1]
amount: "2"
amount: "2.00"
currency: USD
country_code: AT
country_name: Austria
@ -67,17 +67,17 @@
expect:
- payee: Dakota Smith
date: [2017, 11, 8]
amount: "100"
fee: "3"
tax: "0"
amount: "100.00"
fee: "3.0"
tax: "0.0"
currency: USD
payment_id: ch_oxuish6phae2Raighooghi3U
description: "Payment for invoice #102"
- payee: Dakota Jones
date: [2017, 10, 28]
amount: "50"
amount: "50.00"
fee: "1.4"
tax: "0"
tax: "0.0"
currency: USD
payment_id: ch_hHee9ef1aeyee1ruo7ochee9
description: "Payment for invoice #100"

View file

@ -3,6 +3,7 @@ import decimal
import importlib
import itertools
import pathlib
import re
import pytest
import yaml
@ -11,13 +12,21 @@ from import2ledger import importers
from . import DATA_DIR
class TestImporters:
def _value_converter(value):
try:
is_decimal = re.match(r'^[-+]?\d+\.\d+$', value)
except TypeError:
is_decimal = False
if is_decimal:
return decimal.Decimal
else:
return lambda x: x
def _date(parts_list):
return datetime.date(*parts_list)
DATA_TYPES = {
KEY_CONVERTERS = {
'date': _date,
'fee': decimal.Decimal,
'tax': decimal.Decimal,
}
with pathlib.Path(DATA_DIR, 'imports.yml').open() as yaml_file:
@ -30,11 +39,12 @@ class TestImporters:
test['importer'] = getattr(module, class_name)
for expect_result in test['expect']:
for key, type_func in DATA_TYPES.items():
for key, value in expect_result.items():
try:
expect_result[key] = type_func(expect_result[key])
convert_func = KEY_CONVERTERS[key]
except KeyError:
pass
convert_func = _value_converter(value)
expect_result[key] = convert_func(value)
@pytest.mark.parametrize('source_path,importer', [
(t['source'], t['importer']) for t in test_data
@ -50,6 +60,7 @@ class TestImporters:
with source_path.open() as source_file:
importer = import_class(source_file)
for actual, expected in itertools.zip_longest(importer, expect_results):
actual['amount'] = decimal.Decimal(actual['amount'])
assert actual == expected
def test_loader(self):