From d2f8772e08fb84d3af968e6e9fc6fc12098c26cc Mon Sep 17 00:00:00 2001 From: Brett Smith Date: Sun, 31 Dec 2017 11:10:49 -0500 Subject: [PATCH] config: Add open_output_file method. Make this functionality accessible to hooks. --- import2ledger/__main__.py | 9 ++------- import2ledger/config.py | 17 ++++++++++++++++- tests/test_config.py | 4 ++-- 3 files changed, 20 insertions(+), 10 deletions(-) diff --git a/import2ledger/__main__.py b/import2ledger/__main__.py index 9b34130..c03c71d 100644 --- a/import2ledger/__main__.py +++ b/import2ledger/__main__.py @@ -43,12 +43,7 @@ class FileImporter: 'source_path': in_path.as_posix(), 'source_stem': in_path.stem, } - with contextlib.ExitStack() as exit_stack: - output_path = self.config.get_output_path() - if output_path is None: - out_file = self.stdout - else: - out_file = exit_stack.enter_context(output_path.open('a')) + with self.config.open_output_file() as out_file: for importer, template in importers: in_file.seek(0) for entry_data in importer(in_file): @@ -106,7 +101,7 @@ def decimal_context(base=decimal.BasicContext): def main(arglist=None, stdout=sys.stdout, stderr=sys.stderr): try: - my_config = config.Configuration(arglist) + my_config = config.Configuration(arglist, stdout, stderr) except errors.UserInputError as error: my_config.error("{}: {!r}".format(error.strerror, error.user_input)) return 3 diff --git a/import2ledger/config.py b/import2ledger/config.py index 989a9e5..33944f2 100644 --- a/import2ledger/config.py +++ b/import2ledger/config.py @@ -28,7 +28,10 @@ class Configuration: 'unsigned_currency_format': '#,##0.### ¤¤', } - def __init__(self, arglist): + def __init__(self, arglist, stdout, stderr): + self.stdout = stdout + self.stderr = stderr + argparser = self._build_argparser() self.error = argparser.error self.args = argparser.parse_args(arglist) @@ -186,6 +189,14 @@ class Configuration: for secname in self.conffile} self.date_ranges[default_secname] = self._parse_date_range(default_secname) + @contextlib.contextmanager + def _open_path(self, path, fallback_file, *args, **kwargs): + if path is None: + yield fallback_file + else: + with path.open(*args, **kwargs) as open_file: + yield open_file + @contextlib.contextmanager def from_section(self, section_name): prev_section = self.args.use_config @@ -226,6 +237,10 @@ class Configuration: section_config = self._get_section(section_name) return self._s_to_path(section_config['output_path']) + def open_output_file(self, section_name=None): + path = self.get_output_path(section_name) + return self._open_path(path, self.stdout, 'a') + def get_template(self, config_key, section_name=None, factory=template.Template): section_config = self._get_section(section_name) try: diff --git a/tests/test_config.py b/tests/test_config.py index 9fe4fda..44e6abe 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -14,12 +14,12 @@ from import2ledger import config, errors from . import DATA_DIR -def config_from_file(path, arglist=[]): +def config_from_file(path, arglist=[], stdout=None, stderr=None): path = pathlib.Path(path) if not path.is_absolute(): path = DATA_DIR / path arglist = ['-C', path.as_posix(), *arglist, os.devnull] - return config.Configuration(arglist) + return config.Configuration(arglist, stdout, stderr) def test_defaults(): config = config_from_file('test_config.ini', ['--sign', 'GBP', '-O', 'out_arg'])