At some point PyYAML's default loading mode was changed to a safer one, so arbitrary Python YAML tags such as "!!python/object/apply:datetime.date [2017, 9, 1]" no longer work by default. Rather than switching back to the unsafe loader, a better approach is to define our own explicit constructors for the few types the test data needs (e.g. "!date 2017-09-01" and "!decimal 12.50").
		
			
				
	
	
		
			79 lines
		
	
	
	
		
			3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			79 lines
		
	
	
	
		
			3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
import csv
 | 
						|
import datetime
 | 
						|
import decimal
 | 
						|
import io
 | 
						|
import importlib
 | 
						|
import itertools
 | 
						|
import pathlib
 | 
						|
import shutil
 | 
						|
 | 
						|
import pytest
 | 
						|
import yaml
 | 
						|
from import2ledger import importers, strparse
 | 
						|
 | 
						|
from . import DATA_DIR
 | 
						|
 | 
						|
 | 
						|
def decimal_constructor(loader, node):
    """Turn a YAML scalar tagged ``!decimal`` into a ``decimal.Decimal``.

    Used as a loader constructor so test data can spell exact amounts
    without relying on Python-specific YAML tags.
    """
    return decimal.Decimal(loader.construct_scalar(node))
 | 
						|
 | 
						|
 | 
						|
def date_constructor(loader, node):
    """Turn a YAML scalar tagged ``!date`` into a ``datetime.date``.

    The scalar text is parsed with ``datetime.date.fromisoformat``, so it
    must be an ISO-8601 date string such as ``2017-09-01``.
    """
    iso_text = loader.construct_scalar(node)
    return datetime.date.fromisoformat(iso_text)
 | 
						|
 | 
						|
 | 
						|
class TestImporters:
    """Data-driven tests for the importers listed in ``imports.yml``.

    Each YAML entry provides a ``source`` file (relative to ``DATA_DIR``),
    an ``importer`` given as ``module.ClassName`` relative to the
    ``import2ledger.importers`` package, the ``expect``-ed entries, and
    optionally ``header_rows``/``header_cols`` for the squared-CSV test.
    """

    class Loader(yaml.SafeLoader):
        """SafeLoader extended with the ``!decimal`` and ``!date`` tags.

        Subclassing keeps the custom constructors off PyYAML's shared loader
        classes: calling ``add_constructor`` on ``yaml.Loader`` itself would
        register the tags for every user of that class in the process.
        Building on ``SafeLoader`` also avoids ``yaml.Loader``, which in
        PyYAML 5.1+ can construct arbitrary Python objects.
        """

    Loader.add_constructor('!decimal', decimal_constructor)
    Loader.add_constructor('!date', date_constructor)

    with pathlib.Path(DATA_DIR, 'imports.yml').open() as yaml_file:
        test_data = yaml.load(yaml_file, Loader=Loader)

    # Resolve each entry in place: make the source path absolute and replace
    # the dotted importer name with the class object it names.
    for test in test_data:
        test['source'] = DATA_DIR / test['source']
        module_name, class_name = test['importer'].rsplit('.', 1)
        module = importlib.import_module('.' + module_name, 'import2ledger.importers')
        test['importer'] = getattr(module, class_name)

    @pytest.mark.parametrize('source_path,importer', [
        (t['source'], t['importer']) for t in test_data
    ])
    def test_can_import(self, source_path, importer):
        """Each importer must recognize its own sample source file."""
        with source_path.open() as source_file:
            assert importer.can_import(source_file)

    @pytest.mark.parametrize('source_path,importer,header_rows,header_cols', [
        (t['source'], t['importer'], t['header_rows'], t['header_cols'])
        for t in test_data if t.get('header_rows')
    ])
    def test_can_import_squared_csv(self, source_path, importer, header_rows, header_cols):
        # Sometimes when we munge spreadsheets by hand (e.g., to filter by
        # project) tools like LibreOffice Calc write a "squared" spreadsheet,
        # where every row has the same length.  This test ensures the results
        # are still recognized for import.
        with io.StringIO() as squared_file:
            csv_writer = csv.writer(squared_file)
            with source_path.open() as source_file:
                # Pad only the first header_rows rows out to header_cols
                # columns (csv writes None as an empty field), then copy the
                # remainder of the file through unchanged.
                for row in itertools.islice(csv.reader(source_file), header_rows):
                    padding = [None] * (header_cols - len(row))
                    csv_writer.writerow(row + padding)
                shutil.copyfileobj(source_file, squared_file)
            squared_file.seek(0)
            assert importer.can_import(squared_file)

    @pytest.mark.parametrize('source_path,import_class,expect_results', [
        (t['source'], t['importer'], t['expect']) for t in test_data
    ])
    def test_import(self, source_path, import_class, expect_results):
        """The importer must yield exactly the expected entries, in order.

        ``amount`` values are normalized with ``strparse.currency_decimal``
        before comparison.
        """
        # A private sentinel distinguishes "one side ran out" from real values,
        # so a short or long result produces a clear assertion failure instead
        # of a TypeError from subscripting None.
        missing = object()
        with source_path.open() as source_file:
            importer = import_class(source_file)
            for actual, expected in itertools.zip_longest(
                    importer, expect_results, fillvalue=missing):
                assert actual is not missing, \
                    f"importer yielded fewer entries than expected; next expected: {expected!r}"
                assert expected is not missing, \
                    f"importer yielded an unexpected extra entry: {actual!r}"
                actual['amount'] = strparse.currency_decimal(actual['amount'])
                assert actual == expected

    def test_loader(self):
        """Every importer named in the test data is found by ``load_all``."""
        all_importers = list(importers.load_all())
        for test in self.test_data:
            assert test['importer'] in all_importers
 |