main: Provide template variables about the file being imported.

This commit is contained in:
Brett Smith 2017-12-19 09:06:24 -05:00
parent 122323853d
commit cc8c03ab62
5 changed files with 55 additions and 18 deletions

View file

@ -76,6 +76,13 @@ date The date of the transaction, in your configured output
format
------------------ ----------------------------------------------------------
payee The name of the transaction payee
------------------ ----------------------------------------------------------
source_abspath The absolute path of the file being imported
------------------ ----------------------------------------------------------
source_name The filename of the file being imported
------------------ ----------------------------------------------------------
source_path The path of the file being imported, as specified on the
command line
================== ==========================================================
Specific importers and hooks may provide additional variables.

View file

@ -1,3 +1,4 @@
import collections
import contextlib
import logging
import sys
@ -13,7 +14,9 @@ class FileImporter:
self.hooks = [hook(config) for hook in hooks.load_all()]
self.stdout = stdout
def import_file(self, in_file):
def import_file(self, in_file, in_path=None):
if in_path is None:
in_path = pathlib.Path(in_file.name)
importers = []
for importer in self.importers:
in_file.seek(0)
@ -31,6 +34,11 @@ class FileImporter:
importers.append((importer, template))
if not importers:
raise errors.UserInputFileError("no importers available", in_file.name)
source_vars = {
'source_abspath': in_path.absolute().as_posix(),
'source_name': in_path.name,
'source_path': in_path.as_posix(),
}
with contextlib.ExitStack() as exit_stack:
output_path = self.config.get_output_path()
if output_path is None:
@ -48,7 +56,8 @@ class FileImporter:
break
else:
del entry_data['_hook_cancel']
print(template.render(**entry_data), file=out_file, end='')
render_vars = collections.ChainMap(entry_data, source_vars)
print(template.render(render_vars), file=out_file, end='')
def import_path(self, in_path):
if in_path is None:
@ -56,7 +65,7 @@ class FileImporter:
with in_path.open(errors='replace') as in_file:
if not in_file.seekable():
raise errors.UserInputFileError("only seekable files are supported", in_path)
return self.import_file(in_file)
return self.import_file(in_file, in_path)
def import_paths(self, path_seq):
for in_path in path_seq:

View file

@ -8,5 +8,7 @@ template patreon cardfees =
Accrued:Accounts Receivable -{amount}
Expenses:Fees:Credit Card {amount}
template patreon svcfees =
;SourcePath: {source_abspath}
;SourceName: {source_name}
Accrued:Accounts Receivable -{amount}
Expenses:Fundraising {amount}

View file

@ -2,14 +2,18 @@
Accrued:Accounts Receivable $-52.47
Expenses:Fees:Credit Card $52.47
2017/09/01 Patreon
Accrued:Accounts Receivable $-61.73
Expenses:Fundraising $61.73
2017/10/01 Patreon
Accrued:Accounts Receivable $-99.47
Expenses:Fees:Credit Card $99.47
2017/09/01 Patreon
;SourcePath: {source_abspath}
;SourceName: {source_name}
Accrued:Accounts Receivable $-61.73
Expenses:Fundraising $61.73
2017/10/01 Patreon
;SourcePath: {source_abspath}
;SourceName: {source_name}
Accrued:Accounts Receivable $-117.03
Expenses:Fundraising $117.03

View file

@ -30,35 +30,50 @@ def iter_entries(in_file):
if lines:
yield ''.join(lines)
def entries2set(in_file):
return set(normalize_whitespace(e) for e in iter_entries(in_file))
def format_entry(entry_s, format_vars):
return normalize_whitespace(entry_s).format_map(format_vars)
def expected_entries(path):
def format_entries(source, format_vars=None):
if format_vars is None:
format_vars = {}
return (format_entry(e, format_vars) for e in iter_entries(source))
def expected_entries(path, format_vars=None):
path = pathlib.Path(path)
if not path.is_absolute():
path = DATA_DIR / path
with path.open() as in_file:
return entries2set(in_file)
return list(format_entries(in_file, format_vars))
def path_vars(path):
return {
'source_abspath': str(path),
'source_name': path.name,
'source_path': str(path),
}
def test_fees_import():
source_path = pathlib.Path(DATA_DIR, 'PatreonEarnings.csv')
arglist = ARGLIST + [
'-c', 'One',
pathlib.Path(DATA_DIR, 'PatreonEarnings.csv').as_posix(),
source_path.as_posix(),
]
exitcode, stdout, _ = run_main(arglist)
assert exitcode == 0
actual = entries2set(stdout)
assert actual == expected_entries('test_main_fees_import.ledger')
actual = list(format_entries(stdout))
expected = expected_entries('test_main_fees_import.ledger', path_vars(source_path))
assert actual == expected
def test_date_range_import():
source_path = pathlib.Path(DATA_DIR, 'PatreonEarnings.csv')
arglist = ARGLIST + [
'-c', 'One',
'--date-range', '2017/10/01-',
pathlib.Path(DATA_DIR, 'PatreonEarnings.csv').as_posix(),
source_path.as_posix(),
]
exitcode, stdout, _ = run_main(arglist)
assert exitcode == 0
actual = entries2set(stdout)
expected = {entry for entry in expected_entries('test_main_fees_import.ledger')
if entry.startswith('2017/10/')}
actual = list(format_entries(stdout))
valid = expected_entries('test_main_fees_import.ledger', path_vars(source_path))
expected = [entry for entry in valid if entry.startswith('2017/10/')]
assert actual == expected