import datetime
import decimal
import importlib
import itertools
import pathlib
import re

import pytest
import yaml
from import2ledger import importers

from . import DATA_DIR

class TestImporters:
    """Data-driven tests for the importers listed in ``imports.yml``.

    The test table is loaded from ``DATA_DIR / 'imports.yml'`` while the class
    body executes (at import/collection time).  Each entry names a ``source``
    file (relative to ``DATA_DIR``), an ``importer`` given as
    ``module.ClassName`` relative to ``import2ledger.importers``, and an
    ``expect`` list of the entries that importer should yield.
    """

    # The two helpers below are plain functions invoked from the class body
    # while the class is being built; they are not meant to be called as
    # methods.
    def _value_converter(value):
        """Return a callable that converts one expected value from the YAML.

        Strings that look like a fixed-point number (e.g. ``"-12.34"``) are
        converted with ``decimal.Decimal``; every other value, including
        non-strings, is returned unchanged.
        """
        try:
            is_decimal = re.match(r'^[-+]?\d+\.\d+$', value)
        except TypeError:
            # re.match rejects non-string values (ints, lists, None, ...).
            is_decimal = False
        if is_decimal:
            return decimal.Decimal
        else:
            return lambda x: x

    def _date(parts_list):
        """Build a ``datetime.date`` from a ``[year, month, day]`` list."""
        return datetime.date(*parts_list)

    # Per-key converters; keys not listed here fall back to _value_converter.
    KEY_CONVERTERS = {
        'date': _date,
    }

    # yaml.load() without an explicit Loader is unsafe and raises TypeError on
    # PyYAML >= 6; the test table is plain data, so safe_load is sufficient.
    # An empty file loads as None, hence the `or []`.
    with pathlib.Path(DATA_DIR, 'imports.yml').open(encoding='utf-8') as yaml_file:
        test_data = yaml.safe_load(yaml_file) or []
    for test in test_data:
        test['source'] = DATA_DIR / test['source']

        # Resolve "module.ClassName" to the actual importer class.
        module_name, class_name = test['importer'].rsplit('.', 1)
        module = importlib.import_module('.' + module_name, 'import2ledger.importers')
        test['importer'] = getattr(module, class_name)

        # Convert expected values in place.  Reassigning existing keys while
        # iterating items() is safe because the dict's size does not change.
        for expect_result in test['expect']:
            for key, value in expect_result.items():
                try:
                    convert_func = KEY_CONVERTERS[key]
                except KeyError:
                    convert_func = _value_converter(value)
                expect_result[key] = convert_func(value)

    @pytest.mark.parametrize('source_path,importer', [
        (t['source'], t['importer']) for t in test_data
    ])
    def test_can_import(self, source_path, importer):
        """The importer reports that it can handle its own sample file."""
        with source_path.open() as source_file:
            assert importer.can_import(source_file)

    @pytest.mark.parametrize('source_path,import_class,expect_results', [
        (t['source'], t['importer'], t['expect']) for t in test_data
    ])
    def test_import(self, source_path, import_class, expect_results):
        """The importer yields exactly the expected entries, in order."""
        # Private sentinel so a length mismatch is reported clearly instead of
        # surfacing as a TypeError on None['amount'].
        missing = object()
        with source_path.open() as source_file:
            importer = import_class(source_file)
            for actual, expected in itertools.zip_longest(
                    importer, expect_results, fillvalue=missing):
                assert actual is not missing, \
                    "importer yielded fewer entries than expected"
                assert expected is not missing, \
                    "importer yielded more entries than expected"
                # Normalize the amount to Decimal so it compares against the
                # Decimal produced by _value_converter for the expectation.
                actual['amount'] = decimal.Decimal(actual['amount'])
                assert actual == expected

    def test_loader(self):
        """Every importer named in the test table is found by load_all()."""
        all_importers = list(importers.load_all())
        for test in self.test_data:
            assert test['importer'] in all_importers