importers: Refactor out a base CSV importer class.

I'm going to build the Stripe importer on top of this.
This commit is contained in:
Brett Smith 2017-11-09 13:06:06 -05:00
parent 3b821cbbee
commit f8a68c3a2e
2 changed files with 65 additions and 35 deletions

View file

@ -0,0 +1,42 @@
import csv
class CSVImporterBase:
"""Common base class for importing CSV files.
Subclasses must define the following:
* TEMPLATE_KEY: A string, as usual
* NEEDED_FIELDS: A set of columns that must exist in the CSV file for
this class to import it.
* _read_row(self, row): A method that returns an entry data dict, or None
if there's nothing to import from this row.
Subclasses may define the following:
* ENTRY_SEED: A dict with the initial entry data.
* COPIED_FIELDS: A dict that maps column names to data keys. These fields
will be copied directly to the entry data dict before _read_row is called.
Fields named here must exist in the CSV for it to be imported.
"""
ENTRY_SEED = {}
COPIED_FIELDS = {}
@classmethod
def can_import(cls, input_file):
in_csv = csv.reader(input_file)
fields = next(iter(in_csv), [])
return cls.NEEDED_FIELDS.union(cls.COPIED_FIELDS).issubset(fields)
def __init__(self, input_file):
self.in_csv = csv.DictReader(input_file)
self.entry_seed = self.ENTRY_SEED.copy()
def __iter__(self):
for row in self.in_csv:
row_data = self._read_row(row)
if row_data is not None:
retval = self.entry_seed.copy()
retval.update(
(entry_key, row[row_key])
for row_key, entry_key in self.COPIED_FIELDS.items()
)
retval.update(row_data)
yield retval

View file

@ -1,37 +1,21 @@
import csv
import datetime
import pathlib
import re
from . import _csv
from .. import util
class ImporterBase:
@classmethod
def can_import(cls, input_file):
in_csv = csv.reader(input_file)
fields = next(iter(in_csv), [])
return cls.NEEDED_FIELDS.issubset(fields)
def __init__(self, input_file):
self.in_csv = csv.DictReader(input_file)
self.start_data = {'currency': 'USD'}
def __iter__(self):
for row in self.in_csv:
row_data = self._read_row(row)
if row_data is not None:
retval = self.start_data.copy()
retval.update(row_data)
yield retval
class IncomeImporter(ImporterBase):
class IncomeImporter(_csv.CSVImporterBase):
NEEDED_FIELDS = frozenset([
'FirstName',
'LastName',
'Pledge',
'Status',
])
COPIED_FIELDS = {
'Pledge': 'amount',
}
ENTRY_SEED = {
'currency': 'USD',
}
TEMPLATE_KEY = 'template patreon income'
def __init__(self, input_file):
@ -39,28 +23,28 @@ class IncomeImporter(ImporterBase):
match = re.search(r'(?:\b|_)(\d{4}-\d{2}-\d{2})(?:\b|_)',
pathlib.Path(input_file.name).name)
if match:
self.start_data['date'] = util.strpdate(match.group(1), '%Y-%m-%d')
self.entry_seed['date'] = util.strpdate(match.group(1), '%Y-%m-%d')
def _read_row(self, row):
if row['Status'] != 'Processed':
return None
else:
return {
'amount': row['Pledge'],
'payee': '{0[FirstName]} {0[LastName]}'.format(row),
}
class FeeImporterBase(ImporterBase):
def _read_row(self, row):
retval = {
key.lower().replace(' ', '_'): row[key]
for key in self.NEEDED_FIELDS.difference([self.AMOUNT_FIELD, 'Month'])
class FeeImporterBase(_csv.CSVImporterBase):
ENTRY_SEED = {
'currency': 'USD',
'payee': "Patreon",
}
def _read_row(self, row):
return {
'amount': row[self.AMOUNT_FIELD].lstrip('$'),
'date': util.strpdate(row['Month'], '%Y-%m'),
}
retval['amount'] = row[self.AMOUNT_FIELD].lstrip('$')
retval['date'] = util.strpdate(row['Month'], '%Y-%m')
retval['payee'] = "Patreon"
return retval
class PatreonFeeImporter(FeeImporterBase):
@ -77,5 +61,9 @@ class CardFeeImporter(FeeImporterBase):
class VATImporter(FeeImporterBase):
AMOUNT_FIELD = 'Vat Charged'
NEEDED_FIELDS = frozenset(['Country Code', 'Country Name', 'Month', AMOUNT_FIELD])
NEEDED_FIELDS = frozenset(['Month', AMOUNT_FIELD])
COPIED_FIELDS = {
'Country Code': 'country_code',
'Country Name': 'country_name',
}
TEMPLATE_KEY = 'template patreon vat'