main: Provide template variables about the file being imported.

This commit is contained in:
Brett Smith 2017-12-19 09:06:24 -05:00
parent 122323853d
commit cc8c03ab62
5 changed files with 55 additions and 18 deletions

View file

@ -76,6 +76,13 @@ date The date of the transaction, in your configured output
format format
------------------ ---------------------------------------------------------- ------------------ ----------------------------------------------------------
payee The name of the transaction payee payee The name of the transaction payee
------------------ ----------------------------------------------------------
source_abspath The absolute path of the file being imported
------------------ ----------------------------------------------------------
source_name The filename of the file being imported
------------------ ----------------------------------------------------------
source_path The path of the file being imported, as specified on the
command line
================== ========================================================== ================== ==========================================================
Specific importers and hooks may provide additional variables. Specific importers and hooks may provide additional variables.

View file

@ -1,3 +1,4 @@
import collections
import contextlib import contextlib
import logging import logging
import sys import sys
@ -13,7 +14,9 @@ class FileImporter:
self.hooks = [hook(config) for hook in hooks.load_all()] self.hooks = [hook(config) for hook in hooks.load_all()]
self.stdout = stdout self.stdout = stdout
def import_file(self, in_file): def import_file(self, in_file, in_path=None):
if in_path is None:
in_path = pathlib.Path(in_file.name)
importers = [] importers = []
for importer in self.importers: for importer in self.importers:
in_file.seek(0) in_file.seek(0)
@ -31,6 +34,11 @@ class FileImporter:
importers.append((importer, template)) importers.append((importer, template))
if not importers: if not importers:
raise errors.UserInputFileError("no importers available", in_file.name) raise errors.UserInputFileError("no importers available", in_file.name)
source_vars = {
'source_abspath': in_path.absolute().as_posix(),
'source_name': in_path.name,
'source_path': in_path.as_posix(),
}
with contextlib.ExitStack() as exit_stack: with contextlib.ExitStack() as exit_stack:
output_path = self.config.get_output_path() output_path = self.config.get_output_path()
if output_path is None: if output_path is None:
@ -48,7 +56,8 @@ class FileImporter:
break break
else: else:
del entry_data['_hook_cancel'] del entry_data['_hook_cancel']
print(template.render(**entry_data), file=out_file, end='') render_vars = collections.ChainMap(entry_data, source_vars)
print(template.render(render_vars), file=out_file, end='')
def import_path(self, in_path): def import_path(self, in_path):
if in_path is None: if in_path is None:
@ -56,7 +65,7 @@ class FileImporter:
with in_path.open(errors='replace') as in_file: with in_path.open(errors='replace') as in_file:
if not in_file.seekable(): if not in_file.seekable():
raise errors.UserInputFileError("only seekable files are supported", in_path) raise errors.UserInputFileError("only seekable files are supported", in_path)
return self.import_file(in_file) return self.import_file(in_file, in_path)
def import_paths(self, path_seq): def import_paths(self, path_seq):
for in_path in path_seq: for in_path in path_seq:

View file

@ -8,5 +8,7 @@ template patreon cardfees =
Accrued:Accounts Receivable -{amount} Accrued:Accounts Receivable -{amount}
Expenses:Fees:Credit Card {amount} Expenses:Fees:Credit Card {amount}
template patreon svcfees = template patreon svcfees =
;SourcePath: {source_abspath}
;SourceName: {source_name}
Accrued:Accounts Receivable -{amount} Accrued:Accounts Receivable -{amount}
Expenses:Fundraising {amount} Expenses:Fundraising {amount}

View file

@ -2,14 +2,18 @@
Accrued:Accounts Receivable $-52.47 Accrued:Accounts Receivable $-52.47
Expenses:Fees:Credit Card $52.47 Expenses:Fees:Credit Card $52.47
2017/09/01 Patreon
Accrued:Accounts Receivable $-61.73
Expenses:Fundraising $61.73
2017/10/01 Patreon 2017/10/01 Patreon
Accrued:Accounts Receivable $-99.47 Accrued:Accounts Receivable $-99.47
Expenses:Fees:Credit Card $99.47 Expenses:Fees:Credit Card $99.47
2017/09/01 Patreon
;SourcePath: {source_abspath}
;SourceName: {source_name}
Accrued:Accounts Receivable $-61.73
Expenses:Fundraising $61.73
2017/10/01 Patreon 2017/10/01 Patreon
;SourcePath: {source_abspath}
;SourceName: {source_name}
Accrued:Accounts Receivable $-117.03 Accrued:Accounts Receivable $-117.03
Expenses:Fundraising $117.03 Expenses:Fundraising $117.03

View file

@ -30,35 +30,50 @@ def iter_entries(in_file):
if lines: if lines:
yield ''.join(lines) yield ''.join(lines)
def entries2set(in_file): def format_entry(entry_s, format_vars):
return set(normalize_whitespace(e) for e in iter_entries(in_file)) return normalize_whitespace(entry_s).format_map(format_vars)
def expected_entries(path): def format_entries(source, format_vars=None):
if format_vars is None:
format_vars = {}
return (format_entry(e, format_vars) for e in iter_entries(source))
def expected_entries(path, format_vars=None):
path = pathlib.Path(path) path = pathlib.Path(path)
if not path.is_absolute(): if not path.is_absolute():
path = DATA_DIR / path path = DATA_DIR / path
with path.open() as in_file: with path.open() as in_file:
return entries2set(in_file) return list(format_entries(in_file, format_vars))
def path_vars(path):
return {
'source_abspath': str(path),
'source_name': path.name,
'source_path': str(path),
}
def test_fees_import(): def test_fees_import():
source_path = pathlib.Path(DATA_DIR, 'PatreonEarnings.csv')
arglist = ARGLIST + [ arglist = ARGLIST + [
'-c', 'One', '-c', 'One',
pathlib.Path(DATA_DIR, 'PatreonEarnings.csv').as_posix(), source_path.as_posix(),
] ]
exitcode, stdout, _ = run_main(arglist) exitcode, stdout, _ = run_main(arglist)
assert exitcode == 0 assert exitcode == 0
actual = entries2set(stdout) actual = list(format_entries(stdout))
assert actual == expected_entries('test_main_fees_import.ledger') expected = expected_entries('test_main_fees_import.ledger', path_vars(source_path))
assert actual == expected
def test_date_range_import(): def test_date_range_import():
source_path = pathlib.Path(DATA_DIR, 'PatreonEarnings.csv')
arglist = ARGLIST + [ arglist = ARGLIST + [
'-c', 'One', '-c', 'One',
'--date-range', '2017/10/01-', '--date-range', '2017/10/01-',
pathlib.Path(DATA_DIR, 'PatreonEarnings.csv').as_posix(), source_path.as_posix(),
] ]
exitcode, stdout, _ = run_main(arglist) exitcode, stdout, _ = run_main(arglist)
assert exitcode == 0 assert exitcode == 0
actual = entries2set(stdout) actual = list(format_entries(stdout))
expected = {entry for entry in expected_entries('test_main_fees_import.ledger') valid = expected_entries('test_main_fees_import.ledger', path_vars(source_path))
if entry.startswith('2017/10/')} expected = [entry for entry in valid if entry.startswith('2017/10/')]
assert actual == expected assert actual == expected