From 0f4f83e0795fd61795e102c77b5e1d5a9d7a5d9c Mon Sep 17 00:00:00 2001
From: Brett Smith
Date: Wed, 27 Dec 2017 12:52:47 -0500
Subject: [PATCH] importers._csv: Refactor header reading into CSVImporterBase.

---
 import2ledger/importers/_csv.py     | 27 ++++++++++++++----
 import2ledger/importers/benevity.py | 44 ++++++++---------------------
 2 files changed, 34 insertions(+), 37 deletions(-)

diff --git a/import2ledger/importers/_csv.py b/import2ledger/importers/_csv.py
index 754c3fa..368471c 100644
--- a/import2ledger/importers/_csv.py
+++ b/import2ledger/importers/_csv.py
@@ -20,15 +20,32 @@ class CSVImporterBase:
     ENTRY_SEED = {}
     COPIED_FIELDS = {}
 
+    @classmethod
+    def _read_header_row(cls, row):
+        return {} if len(row) < cls._HEADER_MAX_LEN else None
+
+    @classmethod
+    def _read_header(cls, input_file):
+        cls._NEEDED_KEYS = cls.NEEDED_FIELDS.union(cls.COPIED_FIELDS)
+        cls._HEADER_MAX_LEN = len(cls._NEEDED_KEYS)
+        header = {}
+        row = None
+        for row in csv.reader(input_file):
+            row_data = cls._read_header_row(row)
+            if row_data is None:
+                break
+            else:
+                header.update(row_data)
+        return header, row
+
     @classmethod
     def can_import(cls, input_file):
-        in_csv = csv.reader(input_file)
-        fields = next(iter(in_csv), [])
-        return cls.NEEDED_FIELDS.union(cls.COPIED_FIELDS).issubset(fields)
+        _, fields = cls._read_header(input_file)
+        return cls._NEEDED_KEYS.issubset(fields or ())
 
     def __init__(self, input_file):
-        self.in_csv = csv.DictReader(input_file)
-        self.entry_seed = {}
+        self.entry_seed, fields = self._read_header(input_file)
+        self.in_csv = csv.DictReader(input_file, fields)
 
     def __iter__(self):
         for row in self.in_csv:
diff --git a/import2ledger/importers/benevity.py b/import2ledger/importers/benevity.py
index 4e5a34d..c304103 100644
--- a/import2ledger/importers/benevity.py
+++ b/import2ledger/importers/benevity.py
@@ -1,9 +1,12 @@
-import csv
-
 from . import _csv
 from .. import strparse
 
 class PaymentImporter(_csv.CSVImporterBase):
+    HEADER_FIELDS = {
+        'Currency': 'currency',
+        'Disbursement ID': 'disbursement_id',
+        'Reference': 'reference',
+    }
     DATE_FIELD = 'Date of Donation'
     NAME_FIELDS = ['Donor First Name', 'Donor Last Name']
     DECIMAL_FIELDS = {
@@ -28,37 +31,14 @@ class PaymentImporter(_csv.CSVImporterBase):
     NOT_SHARED = 'Not shared by donor'
 
     @classmethod
-    def _read_header(cls, source):
-        needed_keys = cls.NEEDED_FIELDS.union(cls.COPIED_FIELDS)
-        header = {}
-        for row in csv.reader(source):
-            row_len = len(row)
-            if row_len < 2:
-                pass
-            elif row_len == 2:
-                header[row[0]] = row[1]
-            elif needed_keys.issubset(row):
-                return header, csv.DictReader(source, row)
-            else:
-                break
-        raise ValueError("source is not a Benevity CSV")
-
-    @classmethod
-    def can_import(cls, input_file):
-        try:
-            header, _ = cls._read_header(input_file)
-        except ValueError:
-            return False
+    def _read_header_row(cls, row):
+        row_len = len(row)
+        if row_len > 2:
+            return None
+        elif row_len == 2 and row[0] in cls.HEADER_FIELDS:
+            return {cls.HEADER_FIELDS[row[0]]: row[1]}
         else:
-            return bool(header)
-
-    def __init__(self, input_file):
-        header, self.in_csv = self._read_header(input_file)
-        self.entry_seed = {
-            'currency': header['Currency'],
-            'disbursement_id': header['Disbursement ID'],
-            'reference': header['Payment Reference'],
-        }
+            return {}
 
     def _read_row(self, row):
         try: