importers._csv: Refactor header reading into CSVImporterBase.
This commit is contained in:
		
							parent
							
								
									1b1e2d038c
								
							
						
					
					
						commit
						0f4f83e079
					
				
					 2 changed files with 34 additions and 37 deletions
				
			
		| 
						 | 
				
			
			@ -20,15 +20,32 @@ class CSVImporterBase:
 | 
			
		|||
    ENTRY_SEED = {}
 | 
			
		||||
    COPIED_FIELDS = {}
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def _read_header_row(cls, row):
 | 
			
		||||
        return {} if len(row) < cls._HEADER_MAX_LEN else None
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def _read_header(cls, input_file):
 | 
			
		||||
        cls._NEEDED_KEYS = cls.NEEDED_FIELDS.union(cls.COPIED_FIELDS)
 | 
			
		||||
        cls._HEADER_MAX_LEN = len(cls._NEEDED_KEYS)
 | 
			
		||||
        header = {}
 | 
			
		||||
        row = None
 | 
			
		||||
        for row in csv.reader(input_file):
 | 
			
		||||
            row_data = cls._read_header_row(row)
 | 
			
		||||
            if row_data is None:
 | 
			
		||||
                break
 | 
			
		||||
            else:
 | 
			
		||||
                header.update(row_data)
 | 
			
		||||
        return header, row
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def can_import(cls, input_file):
 | 
			
		||||
        in_csv = csv.reader(input_file)
 | 
			
		||||
        fields = next(iter(in_csv), [])
 | 
			
		||||
        return cls.NEEDED_FIELDS.union(cls.COPIED_FIELDS).issubset(fields)
 | 
			
		||||
        _, fields = cls._read_header(input_file)
 | 
			
		||||
        return cls._NEEDED_KEYS.issubset(fields or ())
 | 
			
		||||
 | 
			
		||||
    def __init__(self, input_file):
 | 
			
		||||
        self.in_csv = csv.DictReader(input_file)
 | 
			
		||||
        self.entry_seed = {}
 | 
			
		||||
        self.entry_seed, fields = self._read_header(input_file)
 | 
			
		||||
        self.in_csv = csv.DictReader(input_file, fields)
 | 
			
		||||
 | 
			
		||||
    def __iter__(self):
 | 
			
		||||
        for row in self.in_csv:
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,9 +1,12 @@
 | 
			
		|||
import csv
 | 
			
		||||
 | 
			
		||||
from . import _csv
 | 
			
		||||
from .. import strparse
 | 
			
		||||
 | 
			
		||||
class PaymentImporter(_csv.CSVImporterBase):
 | 
			
		||||
    HEADER_FIELDS = {
 | 
			
		||||
        'Currency': 'currency',
 | 
			
		||||
        'Disbursement ID': 'disbursement_id',
 | 
			
		||||
        'Reference': 'reference',
 | 
			
		||||
    }
 | 
			
		||||
    DATE_FIELD = 'Date of Donation'
 | 
			
		||||
    NAME_FIELDS = ['Donor First Name', 'Donor Last Name']
 | 
			
		||||
    DECIMAL_FIELDS = {
 | 
			
		||||
| 
						 | 
				
			
			@ -28,37 +31,14 @@ class PaymentImporter(_csv.CSVImporterBase):
 | 
			
		|||
    NOT_SHARED = 'Not shared by donor'
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def _read_header(cls, source):
 | 
			
		||||
        needed_keys = cls.NEEDED_FIELDS.union(cls.COPIED_FIELDS)
 | 
			
		||||
        header = {}
 | 
			
		||||
        for row in csv.reader(source):
 | 
			
		||||
            row_len = len(row)
 | 
			
		||||
            if row_len < 2:
 | 
			
		||||
                pass
 | 
			
		||||
            elif row_len == 2:
 | 
			
		||||
                header[row[0]] = row[1]
 | 
			
		||||
            elif needed_keys.issubset(row):
 | 
			
		||||
                return header, csv.DictReader(source, row)
 | 
			
		||||
            else:
 | 
			
		||||
                break
 | 
			
		||||
        raise ValueError("source is not a Benevity CSV")
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def can_import(cls, input_file):
 | 
			
		||||
        try:
 | 
			
		||||
            header, _ = cls._read_header(input_file)
 | 
			
		||||
        except ValueError:
 | 
			
		||||
            return False
 | 
			
		||||
    def _read_header_row(cls, row):
 | 
			
		||||
        row_len = len(row)
 | 
			
		||||
        if row_len > 2:
 | 
			
		||||
            return None
 | 
			
		||||
        elif row_len == 2 and row[0] in cls.HEADER_FIELDS:
 | 
			
		||||
            return {cls.HEADER_FIELDS[row[0]]: row[1]}
 | 
			
		||||
        else:
 | 
			
		||||
            return bool(header)
 | 
			
		||||
 | 
			
		||||
    def __init__(self, input_file):
 | 
			
		||||
        header, self.in_csv = self._read_header(input_file)
 | 
			
		||||
        self.entry_seed = {
 | 
			
		||||
            'currency': header['Currency'],
 | 
			
		||||
            'disbursement_id': header['Disbursement ID'],
 | 
			
		||||
            'reference': header['Payment Reference'],
 | 
			
		||||
        }
 | 
			
		||||
            return {}
 | 
			
		||||
 | 
			
		||||
    def _read_row(self, row):
 | 
			
		||||
        try:
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
	Add table
		
		Reference in a new issue