import2ledger/tests/test_importers.py

80 lines
3 KiB
Python
Raw Normal View History

import csv
import datetime
import decimal
import io
import importlib
import itertools
import pathlib
import shutil
import pytest
import yaml
from import2ledger import importers, strparse
from . import DATA_DIR
def decimal_constructor(loader, node):
    """Build a :class:`decimal.Decimal` from a YAML ``!decimal`` scalar node.

    Registered for the ``!decimal`` tag on the loader that reads the test
    fixture data, so the scalar's text is converted exactly (no float step).
    """
    return decimal.Decimal(loader.construct_scalar(node))
def date_constructor(loader, node):
    """Build a :class:`datetime.date` from a YAML ``!date`` scalar node.

    Registered for the ``!date`` tag on the loader that reads the test
    fixture data. The scalar text must be in a form accepted by
    :meth:`datetime.date.fromisoformat` (e.g. ``2017-03-14``).
    """
    return datetime.date.fromisoformat(loader.construct_scalar(node))
class TestImporters:
    """Data-driven tests for the importers listed in ``imports.yml``.

    Each fixture entry names a ``source`` file (relative to ``DATA_DIR``), an
    ``importer`` given as ``module.ClassName`` relative to
    ``import2ledger.importers``, the ``expect``-ed entries, and optionally
    ``header_rows``/``header_cols`` for the squared-CSV test.
    """

    # Subclass instead of mutating yaml.Loader itself: add_constructor on a
    # subclass registers the tags on that subclass only, so the custom tags
    # do not leak into every other yaml.Loader user in the test process.
    class Loader(yaml.Loader):
        pass

    Loader.add_constructor('!decimal', decimal_constructor)
    Loader.add_constructor('!date', date_constructor)

    with pathlib.Path(DATA_DIR, 'imports.yml').open() as yaml_file:
        # NOTE(review): yaml.Loader can construct arbitrary Python objects;
        # acceptable only because imports.yml is a trusted in-repo fixture.
        test_data = yaml.load(yaml_file, Loader=Loader)

    # Resolve each entry in place: make the source path absolute and replace
    # the dotted importer name with the class object it refers to.
    for test in test_data:
        test['source'] = DATA_DIR / test['source']
        module_name, class_name = test['importer'].rsplit('.', 1)
        module = importlib.import_module('.' + module_name, 'import2ledger.importers')
        test['importer'] = getattr(module, class_name)

    @pytest.mark.parametrize('source_path,importer', [
        (t['source'], t['importer']) for t in test_data
    ])
    def test_can_import(self, source_path, importer):
        """Each importer recognizes its own fixture file."""
        with source_path.open() as source_file:
            assert importer.can_import(source_file)

    @pytest.mark.parametrize('source_path,importer,header_rows,header_cols', [
        (t['source'], t['importer'], t['header_rows'], t['header_cols'])
        for t in test_data if t.get('header_rows')
    ])
    def test_can_import_squared_csv(self, source_path, importer, header_rows, header_cols):
        """Importers still recognize a CSV whose header rows are padded out.

        Sometimes when we munge spreadsheets by hand (e.g., to filter by
        project) tools like LibreOffice Calc write a "squared" spreadsheet,
        where every row has the same length. This test ensures the results
        are still recognized for import.
        """
        with io.StringIO() as squared_file:
            csv_writer = csv.writer(squared_file)
            with source_path.open() as source_file:
                # Pad only the first header_rows rows to header_cols columns;
                # csv.writer writes each None as an empty field.
                for row in itertools.islice(csv.reader(source_file), header_rows):
                    padding = [None] * (header_cols - len(row))
                    csv_writer.writerow(row + padding)
                # Copy the remainder of the file through unchanged.
                shutil.copyfileobj(source_file, squared_file)
            squared_file.seek(0)
            assert importer.can_import(squared_file)

    @pytest.mark.parametrize('source_path,import_class,expect_results', [
        (t['source'], t['importer'], t['expect']) for t in test_data
    ])
    def test_import(self, source_path, import_class, expect_results):
        """Importing each fixture yields exactly the expected entries, in order."""
        with source_path.open() as source_file:
            importer = import_class(source_file)
            for actual, expected in itertools.zip_longest(importer, expect_results):
                # zip_longest pads the shorter side with None. Without this
                # check, too few results would surface as a TypeError on the
                # subscript below rather than as a readable test failure.
                assert actual is not None, (
                    f"importer yielded too few entries; missing {expected!r}"
                )
                # Normalize the amount so it compares equal to the fixture's
                # !decimal values.
                actual['amount'] = strparse.currency_decimal(actual['amount'])
                assert actual == expected

    def test_loader(self):
        """Every importer named in the fixture is returned by importers.load_all()."""
        all_importers = list(importers.load_all())
        for test in self.test_data:
            assert test['importer'] in all_importers