import2ledger/tests/test_importers.py
Ben Sturmfels 61b9683743
Add !date and !decimal YAML constructors to avoid unsafe parsing mode
At some point PyYAML's defaults were switched to a safe parsing mode, so the
arbitrary-Python YAML tags we previously used, such as
"!!python/object/apply:datetime.date [2017, 9, 1]", no longer work. Rather than
opting back in to unsafe mode, define our own explicit constructors for these tags.
2025-09-19 17:21:47 +10:00

79 lines
3 KiB
Python

import csv
import datetime
import decimal
import io
import importlib
import itertools
import pathlib
import shutil
import pytest
import yaml
from import2ledger import importers, strparse
from . import DATA_DIR
def decimal_constructor(loader, node):
    """Build a ``decimal.Decimal`` from the text of a YAML scalar node.

    Used as the constructor for the ``!decimal`` tag so amounts in the test
    fixture stay exact instead of being parsed as binary floats.
    """
    return decimal.Decimal(loader.construct_scalar(node))
def date_constructor(loader, node):
    """Build a ``datetime.date`` from the text of a YAML scalar node.

    Used as the constructor for the ``!date`` tag; the scalar text is parsed
    with ``datetime.date.fromisoformat`` (e.g. ``2017-09-01``).
    """
    iso_text = loader.construct_scalar(node)
    parsed = datetime.date.fromisoformat(iso_text)
    return parsed
class TestImporters:
    """Data-driven tests for the importers listed in ``imports.yml``.

    Each entry in the YAML fixture gives a ``source`` file (relative to
    ``DATA_DIR``), an ``importer`` written as ``module.ClassName`` relative to
    the ``import2ledger.importers`` package, the ``expect``-ed results, and
    optionally ``header_rows``/``header_cols`` for the squared-CSV test.
    While the class body runs, source paths are made absolute and importer
    names are resolved to the actual classes.
    """

    class Loader(yaml.SafeLoader):
        """SafeLoader that also understands the fixture's ``!decimal``/``!date`` tags.

        A private subclass is used on purpose: calling ``add_constructor`` on
        ``yaml.Loader`` directly would register the tags process-wide, and
        ``yaml.Loader`` is PyYAML's unsafe loader (it accepts
        ``!!python/object/apply`` and similar tags), which is exactly what the
        custom constructors exist to avoid.
        """

    # add_constructor copies the inherited constructor table into the
    # subclass first, so yaml.SafeLoader itself is left untouched.
    Loader.add_constructor('!decimal', decimal_constructor)
    Loader.add_constructor('!date', date_constructor)

    # YAML streams are UTF-8 by default; don't depend on the locale encoding.
    with pathlib.Path(DATA_DIR, 'imports.yml').open(encoding='utf-8') as yaml_file:
        test_data = yaml.load(yaml_file, Loader=Loader)
    for test in test_data:
        test['source'] = DATA_DIR / test['source']
        # 'module.ClassName' is resolved relative to import2ledger.importers.
        module_name, class_name = test['importer'].rsplit('.', 1)
        module = importlib.import_module('.' + module_name, 'import2ledger.importers')
        test['importer'] = getattr(module, class_name)

    @pytest.mark.parametrize('source_path,importer', [
        (t['source'], t['importer']) for t in test_data
    ])
    def test_can_import(self, source_path, importer):
        """Each importer recognizes its own sample source file."""
        with source_path.open() as source_file:
            assert importer.can_import(source_file)

    @pytest.mark.parametrize('source_path,importer,header_rows,header_cols', [
        (t['source'], t['importer'], t['header_rows'], t['header_cols'])
        for t in test_data if t.get('header_rows')
    ])
    def test_can_import_squared_csv(self, source_path, importer, header_rows, header_cols):
        """Importers still recognize a CSV whose header rows were padded out."""
        # Sometimes when we munge spreadsheets by hand (e.g., to filter by
        # project), tools like LibreOffice Calc write a "squared" spreadsheet
        # where every row has the same length. This test ensures such files
        # are still recognized for import.
        with io.StringIO() as squared_file:
            csv_writer = csv.writer(squared_file)
            with source_path.open() as source_file:
                # Pad only the first header_rows rows out to header_cols cells...
                for row in itertools.islice(csv.reader(source_file), header_rows):
                    padding = [None] * (header_cols - len(row))
                    csv_writer.writerow(row + padding)
                # ...then copy the rest of the file through unchanged.
                shutil.copyfileobj(source_file, squared_file)
            squared_file.seek(0)
            assert importer.can_import(squared_file)

    @pytest.mark.parametrize('source_path,import_class,expect_results', [
        (t['source'], t['importer'], t['expect']) for t in test_data
    ])
    def test_import(self, source_path, import_class, expect_results):
        """Each importer yields exactly the expected results, in order."""
        with source_path.open() as source_file:
            importer = import_class(source_file)
            for actual, expected in itertools.zip_longest(importer, expect_results):
                # zip_longest pads the shorter side with None; report a missing
                # result clearly instead of failing with a TypeError below.
                assert actual is not None, (
                    f"importer produced fewer results than expected; missing {expected!r}"
                )
                # Normalize the amount so it compares equal to the fixture's
                # !decimal values.
                actual['amount'] = strparse.currency_decimal(actual['amount'])
                assert actual == expected

    def test_loader(self):
        """Every importer referenced by the fixture is found by load_all()."""
        all_importers = list(importers.load_all())
        for test in self.test_data:
            assert test['importer'] in all_importers