tests: Support more Decimal values from importers without specifying them all.

This commit is contained in:
Brett Smith 2017-12-17 13:42:51 -05:00
parent e3ec03cf19
commit a924d9fb1f
2 changed files with 25 additions and 14 deletions

View file

@ -3,11 +3,11 @@
expect: expect:
- payee: Alex Jones - payee: Alex Jones
date: [2017, 9, 1] date: [2017, 9, 1]
amount: "150" amount: "150.00"
currency: USD currency: USD
- payee: Dakota Doe - payee: Dakota Doe
date: [2017, 9, 1] date: [2017, 9, 1]
amount: "12" amount: "12.00"
currency: USD currency: USD
- source: PatreonEarnings.csv - source: PatreonEarnings.csv
@ -39,7 +39,7 @@
expect: expect:
- payee: Patreon - payee: Patreon
date: [2017, 9, 1] date: [2017, 9, 1]
amount: "2" amount: "2.00"
currency: USD currency: USD
country_code: AT country_code: AT
country_name: Austria country_name: Austria
@ -67,17 +67,17 @@
expect: expect:
- payee: Dakota Smith - payee: Dakota Smith
date: [2017, 11, 8] date: [2017, 11, 8]
amount: "100" amount: "100.00"
fee: "3" fee: "3.0"
tax: "0" tax: "0.0"
currency: USD currency: USD
payment_id: ch_oxuish6phae2Raighooghi3U payment_id: ch_oxuish6phae2Raighooghi3U
description: "Payment for invoice #102" description: "Payment for invoice #102"
- payee: Dakota Jones - payee: Dakota Jones
date: [2017, 10, 28] date: [2017, 10, 28]
amount: "50" amount: "50.00"
fee: "1.4" fee: "1.4"
tax: "0" tax: "0.0"
currency: USD currency: USD
payment_id: ch_hHee9ef1aeyee1ruo7ochee9 payment_id: ch_hHee9ef1aeyee1ruo7ochee9
description: "Payment for invoice #100" description: "Payment for invoice #100"

View file

@ -3,6 +3,7 @@ import decimal
import importlib import importlib
import itertools import itertools
import pathlib import pathlib
import re
import pytest import pytest
import yaml import yaml
@ -11,13 +12,21 @@ from import2ledger import importers
from . import DATA_DIR from . import DATA_DIR
class TestImporters: class TestImporters:
def _value_converter(value):
try:
is_decimal = re.match(r'^[-+]?\d+\.\d+$', value)
except TypeError:
is_decimal = False
if is_decimal:
return decimal.Decimal
else:
return lambda x: x
def _date(parts_list): def _date(parts_list):
return datetime.date(*parts_list) return datetime.date(*parts_list)
DATA_TYPES = { KEY_CONVERTERS = {
'date': _date, 'date': _date,
'fee': decimal.Decimal,
'tax': decimal.Decimal,
} }
with pathlib.Path(DATA_DIR, 'imports.yml').open() as yaml_file: with pathlib.Path(DATA_DIR, 'imports.yml').open() as yaml_file:
@ -30,11 +39,12 @@ class TestImporters:
test['importer'] = getattr(module, class_name) test['importer'] = getattr(module, class_name)
for expect_result in test['expect']: for expect_result in test['expect']:
for key, type_func in DATA_TYPES.items(): for key, value in expect_result.items():
try: try:
expect_result[key] = type_func(expect_result[key]) convert_func = KEY_CONVERTERS[key]
except KeyError: except KeyError:
pass convert_func = _value_converter(value)
expect_result[key] = convert_func(value)
@pytest.mark.parametrize('source_path,importer', [ @pytest.mark.parametrize('source_path,importer', [
(t['source'], t['importer']) for t in test_data (t['source'], t['importer']) for t in test_data
@ -50,6 +60,7 @@ class TestImporters:
with source_path.open() as source_file: with source_path.open() as source_file:
importer = import_class(source_file) importer = import_class(source_file)
for actual, expected in itertools.zip_longest(importer, expect_results): for actual, expected in itertools.zip_longest(importer, expect_results):
actual['amount'] = decimal.Decimal(actual['amount'])
assert actual == expected assert actual == expected
def test_loader(self): def test_loader(self):