config: Add open_output_file method.

Make this functionality accessible to hooks.
This commit is contained in:
Brett Smith 2017-12-31 11:10:49 -05:00
parent efe5768941
commit d2f8772e08
3 changed files with 20 additions and 10 deletions

View file

@ -43,12 +43,7 @@ class FileImporter:
'source_path': in_path.as_posix(),
'source_stem': in_path.stem,
}
with contextlib.ExitStack() as exit_stack:
output_path = self.config.get_output_path()
if output_path is None:
out_file = self.stdout
else:
out_file = exit_stack.enter_context(output_path.open('a'))
with self.config.open_output_file() as out_file:
for importer, template in importers:
in_file.seek(0)
for entry_data in importer(in_file):
@ -106,7 +101,7 @@ def decimal_context(base=decimal.BasicContext):
def main(arglist=None, stdout=sys.stdout, stderr=sys.stderr):
try:
my_config = config.Configuration(arglist)
my_config = config.Configuration(arglist, stdout, stderr)
except errors.UserInputError as error:
my_config.error("{}: {!r}".format(error.strerror, error.user_input))
return 3

View file

@ -28,7 +28,10 @@ class Configuration:
'unsigned_currency_format': '#,##0.### ¤¤',
}
def __init__(self, arglist):
def __init__(self, arglist, stdout, stderr):
self.stdout = stdout
self.stderr = stderr
argparser = self._build_argparser()
self.error = argparser.error
self.args = argparser.parse_args(arglist)
@ -186,6 +189,14 @@ class Configuration:
for secname in self.conffile}
self.date_ranges[default_secname] = self._parse_date_range(default_secname)
@contextlib.contextmanager
def _open_path(self, path, fallback_file, *args, **kwargs):
if path is None:
yield fallback_file
else:
with path.open(*args, **kwargs) as open_file:
yield open_file
@contextlib.contextmanager
def from_section(self, section_name):
prev_section = self.args.use_config
@ -226,6 +237,10 @@ class Configuration:
section_config = self._get_section(section_name)
return self._s_to_path(section_config['output_path'])
def open_output_file(self, section_name=None):
path = self.get_output_path(section_name)
return self._open_path(path, self.stdout, 'a')
def get_template(self, config_key, section_name=None, factory=template.Template):
section_config = self._get_section(section_name)
try:

View file

@ -14,12 +14,12 @@ from import2ledger import config, errors
from . import DATA_DIR
def config_from_file(path, arglist=[]):
def config_from_file(path, arglist=[], stdout=None, stderr=None):
path = pathlib.Path(path)
if not path.is_absolute():
path = DATA_DIR / path
arglist = ['-C', path.as_posix(), *arglist, os.devnull]
return config.Configuration(arglist)
return config.Configuration(arglist, stdout, stderr)
def test_defaults():
config = config_from_file('test_config.ini', ['--sign', 'GBP', '-O', 'out_arg'])