main: Provide template variables about the file being imported.
This commit is contained in:
		
							parent
							
								
									122323853d
								
							
						
					
					
						commit
						cc8c03ab62
					
				
					 5 changed files with 55 additions and 18 deletions
				
			
		| 
						 | 
				
			
			@ -76,6 +76,13 @@ date               The date of the transaction, in your configured output
 | 
			
		|||
                   format
 | 
			
		||||
------------------ ----------------------------------------------------------
 | 
			
		||||
payee              The name of the transaction payee
 | 
			
		||||
------------------ ----------------------------------------------------------
 | 
			
		||||
source_abspath     The absolute path of the file being imported
 | 
			
		||||
------------------ ----------------------------------------------------------
 | 
			
		||||
source_name        The filename of the file being imported
 | 
			
		||||
------------------ ----------------------------------------------------------
 | 
			
		||||
source_path        The path of the file being imported, as specified on the
 | 
			
		||||
                   command line
 | 
			
		||||
================== ==========================================================
 | 
			
		||||
 | 
			
		||||
Specific importers and hooks may provide additional variables.
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,3 +1,4 @@
 | 
			
		|||
import collections
 | 
			
		||||
import contextlib
 | 
			
		||||
import logging
 | 
			
		||||
import sys
 | 
			
		||||
| 
						 | 
				
			
			@ -13,7 +14,9 @@ class FileImporter:
 | 
			
		|||
        self.hooks = [hook(config) for hook in hooks.load_all()]
 | 
			
		||||
        self.stdout = stdout
 | 
			
		||||
 | 
			
		||||
    def import_file(self, in_file):
 | 
			
		||||
    def import_file(self, in_file, in_path=None):
 | 
			
		||||
        if in_path is None:
 | 
			
		||||
            in_path = pathlib.Path(in_file.name)
 | 
			
		||||
        importers = []
 | 
			
		||||
        for importer in self.importers:
 | 
			
		||||
            in_file.seek(0)
 | 
			
		||||
| 
						 | 
				
			
			@ -31,6 +34,11 @@ class FileImporter:
 | 
			
		|||
                    importers.append((importer, template))
 | 
			
		||||
        if not importers:
 | 
			
		||||
            raise errors.UserInputFileError("no importers available", in_file.name)
 | 
			
		||||
        source_vars = {
 | 
			
		||||
            'source_abspath': in_path.absolute().as_posix(),
 | 
			
		||||
            'source_name': in_path.name,
 | 
			
		||||
            'source_path': in_path.as_posix(),
 | 
			
		||||
        }
 | 
			
		||||
        with contextlib.ExitStack() as exit_stack:
 | 
			
		||||
            output_path = self.config.get_output_path()
 | 
			
		||||
            if output_path is None:
 | 
			
		||||
| 
						 | 
				
			
			@ -48,7 +56,8 @@ class FileImporter:
 | 
			
		|||
                            break
 | 
			
		||||
                    else:
 | 
			
		||||
                        del entry_data['_hook_cancel']
 | 
			
		||||
                        print(template.render(**entry_data), file=out_file, end='')
 | 
			
		||||
                        render_vars = collections.ChainMap(entry_data, source_vars)
 | 
			
		||||
                        print(template.render(render_vars), file=out_file, end='')
 | 
			
		||||
 | 
			
		||||
    def import_path(self, in_path):
 | 
			
		||||
        if in_path is None:
 | 
			
		||||
| 
						 | 
				
			
			@ -56,7 +65,7 @@ class FileImporter:
 | 
			
		|||
        with in_path.open(errors='replace') as in_file:
 | 
			
		||||
            if not in_file.seekable():
 | 
			
		||||
                raise errors.UserInputFileError("only seekable files are supported", in_path)
 | 
			
		||||
            return self.import_file(in_file)
 | 
			
		||||
            return self.import_file(in_file, in_path)
 | 
			
		||||
 | 
			
		||||
    def import_paths(self, path_seq):
 | 
			
		||||
        for in_path in path_seq:
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -8,5 +8,7 @@ template patreon cardfees =
 | 
			
		|||
 Accrued:Accounts Receivable  -{amount}
 | 
			
		||||
 Expenses:Fees:Credit Card  {amount}
 | 
			
		||||
template patreon svcfees =
 | 
			
		||||
 ;SourcePath: {source_abspath}
 | 
			
		||||
 ;SourceName: {source_name}
 | 
			
		||||
 Accrued:Accounts Receivable  -{amount}
 | 
			
		||||
 Expenses:Fundraising  {amount}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -2,14 +2,18 @@
 | 
			
		|||
  Accrued:Accounts Receivable  $-52.47
 | 
			
		||||
  Expenses:Fees:Credit Card  $52.47
 | 
			
		||||
 | 
			
		||||
2017/09/01 Patreon
 | 
			
		||||
  Accrued:Accounts Receivable  $-61.73
 | 
			
		||||
  Expenses:Fundraising  $61.73
 | 
			
		||||
 | 
			
		||||
2017/10/01 Patreon
 | 
			
		||||
  Accrued:Accounts Receivable  $-99.47
 | 
			
		||||
  Expenses:Fees:Credit Card  $99.47
 | 
			
		||||
 | 
			
		||||
2017/09/01 Patreon
 | 
			
		||||
  ;SourcePath: {source_abspath}
 | 
			
		||||
  ;SourceName: {source_name}
 | 
			
		||||
  Accrued:Accounts Receivable  $-61.73
 | 
			
		||||
  Expenses:Fundraising  $61.73
 | 
			
		||||
 | 
			
		||||
2017/10/01 Patreon
 | 
			
		||||
  ;SourcePath: {source_abspath}
 | 
			
		||||
  ;SourceName: {source_name}
 | 
			
		||||
  Accrued:Accounts Receivable  $-117.03
 | 
			
		||||
  Expenses:Fundraising  $117.03
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -30,35 +30,50 @@ def iter_entries(in_file):
 | 
			
		|||
    if lines:
 | 
			
		||||
        yield ''.join(lines)
 | 
			
		||||
 | 
			
		||||
def entries2set(in_file):
 | 
			
		||||
    return set(normalize_whitespace(e) for e in iter_entries(in_file))
 | 
			
		||||
def format_entry(entry_s, format_vars):
 | 
			
		||||
    return normalize_whitespace(entry_s).format_map(format_vars)
 | 
			
		||||
 | 
			
		||||
def expected_entries(path):
 | 
			
		||||
def format_entries(source, format_vars=None):
 | 
			
		||||
    if format_vars is None:
 | 
			
		||||
        format_vars = {}
 | 
			
		||||
    return (format_entry(e, format_vars) for e in iter_entries(source))
 | 
			
		||||
 | 
			
		||||
def expected_entries(path, format_vars=None):
 | 
			
		||||
    path = pathlib.Path(path)
 | 
			
		||||
    if not path.is_absolute():
 | 
			
		||||
        path = DATA_DIR / path
 | 
			
		||||
    with path.open() as in_file:
 | 
			
		||||
        return entries2set(in_file)
 | 
			
		||||
        return list(format_entries(in_file, format_vars))
 | 
			
		||||
 | 
			
		||||
def path_vars(path):
 | 
			
		||||
    return {
 | 
			
		||||
        'source_abspath': str(path),
 | 
			
		||||
        'source_name': path.name,
 | 
			
		||||
        'source_path': str(path),
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
def test_fees_import():
 | 
			
		||||
    source_path = pathlib.Path(DATA_DIR, 'PatreonEarnings.csv')
 | 
			
		||||
    arglist = ARGLIST + [
 | 
			
		||||
        '-c', 'One',
 | 
			
		||||
        pathlib.Path(DATA_DIR, 'PatreonEarnings.csv').as_posix(),
 | 
			
		||||
        source_path.as_posix(),
 | 
			
		||||
    ]
 | 
			
		||||
    exitcode, stdout, _ = run_main(arglist)
 | 
			
		||||
    assert exitcode == 0
 | 
			
		||||
    actual = entries2set(stdout)
 | 
			
		||||
    assert actual == expected_entries('test_main_fees_import.ledger')
 | 
			
		||||
    actual = list(format_entries(stdout))
 | 
			
		||||
    expected = expected_entries('test_main_fees_import.ledger', path_vars(source_path))
 | 
			
		||||
    assert actual == expected
 | 
			
		||||
 | 
			
		||||
def test_date_range_import():
 | 
			
		||||
    source_path = pathlib.Path(DATA_DIR, 'PatreonEarnings.csv')
 | 
			
		||||
    arglist = ARGLIST + [
 | 
			
		||||
        '-c', 'One',
 | 
			
		||||
        '--date-range', '2017/10/01-',
 | 
			
		||||
        pathlib.Path(DATA_DIR, 'PatreonEarnings.csv').as_posix(),
 | 
			
		||||
        source_path.as_posix(),
 | 
			
		||||
    ]
 | 
			
		||||
    exitcode, stdout, _ = run_main(arglist)
 | 
			
		||||
    assert exitcode == 0
 | 
			
		||||
    actual = entries2set(stdout)
 | 
			
		||||
    expected = {entry for entry in expected_entries('test_main_fees_import.ledger')
 | 
			
		||||
                if entry.startswith('2017/10/')}
 | 
			
		||||
    actual = list(format_entries(stdout))
 | 
			
		||||
    valid = expected_entries('test_main_fees_import.ledger', path_vars(source_path))
 | 
			
		||||
    expected = [entry for entry in valid if entry.startswith('2017/10/')]
 | 
			
		||||
    assert actual == expected
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
	Add table
		
		Reference in a new issue