importers._csv: Refactor header reading into CSVImporterBase.

This commit is contained in:
Brett Smith 2017-12-27 12:52:47 -05:00
parent 1b1e2d038c
commit 0f4f83e079
2 changed files with 34 additions and 37 deletions

View file

@ -20,15 +20,32 @@ class CSVImporterBase:
ENTRY_SEED = {}
COPIED_FIELDS = {}
@classmethod
def _read_header_row(cls, row):
return {} if len(row) < cls._HEADER_MAX_LEN else None
@classmethod
def _read_header(cls, input_file):
cls._NEEDED_KEYS = cls.NEEDED_FIELDS.union(cls.COPIED_FIELDS)
cls._HEADER_MAX_LEN = len(cls._NEEDED_KEYS)
header = {}
row = None
for row in csv.reader(input_file):
row_data = cls._read_header_row(row)
if row_data is None:
break
else:
header.update(row_data)
return header, row
@classmethod
def can_import(cls, input_file):
in_csv = csv.reader(input_file)
fields = next(iter(in_csv), [])
return cls.NEEDED_FIELDS.union(cls.COPIED_FIELDS).issubset(fields)
_, fields = cls._read_header(input_file)
return cls._NEEDED_KEYS.issubset(fields or ())
def __init__(self, input_file):
self.in_csv = csv.DictReader(input_file)
self.entry_seed = {}
self.entry_seed, fields = self._read_header(input_file)
self.in_csv = csv.DictReader(input_file, fields)
def __iter__(self):
for row in self.in_csv:

View file

@ -1,9 +1,12 @@
import csv
from . import _csv
from .. import strparse
class PaymentImporter(_csv.CSVImporterBase):
HEADER_FIELDS = {
'Currency': 'currency',
'Disbursement ID': 'disbursement_id',
'Reference': 'reference',
}
DATE_FIELD = 'Date of Donation'
NAME_FIELDS = ['Donor First Name', 'Donor Last Name']
DECIMAL_FIELDS = {
@ -28,37 +31,14 @@ class PaymentImporter(_csv.CSVImporterBase):
NOT_SHARED = 'Not shared by donor'
@classmethod
def _read_header(cls, source):
needed_keys = cls.NEEDED_FIELDS.union(cls.COPIED_FIELDS)
header = {}
for row in csv.reader(source):
row_len = len(row)
if row_len < 2:
pass
elif row_len == 2:
header[row[0]] = row[1]
elif needed_keys.issubset(row):
return header, csv.DictReader(source, row)
else:
break
raise ValueError("source is not a Benevity CSV")
@classmethod
def can_import(cls, input_file):
try:
header, _ = cls._read_header(input_file)
except ValueError:
return False
def _read_header_row(cls, row):
row_len = len(row)
if row_len > 2:
return None
elif row_len == 2 and row[0] in cls.HEADER_FIELDS:
return {cls.HEADER_FIELDS[row[0]]: row[1]}
else:
return bool(header)
def __init__(self, input_file):
header, self.in_csv = self._read_header(input_file)
self.entry_seed = {
'currency': header['Currency'],
'disbursement_id': header['Disbursement ID'],
'reference': header['Payment Reference'],
}
return {}
def _read_row(self, row):
try: