config: Add open_output_file method.
Make this functionality accessible to hooks.
This commit is contained in:
parent
efe5768941
commit
d2f8772e08
3 changed files with 20 additions and 10 deletions
|
@ -43,12 +43,7 @@ class FileImporter:
|
|||
'source_path': in_path.as_posix(),
|
||||
'source_stem': in_path.stem,
|
||||
}
|
||||
with contextlib.ExitStack() as exit_stack:
|
||||
output_path = self.config.get_output_path()
|
||||
if output_path is None:
|
||||
out_file = self.stdout
|
||||
else:
|
||||
out_file = exit_stack.enter_context(output_path.open('a'))
|
||||
with self.config.open_output_file() as out_file:
|
||||
for importer, template in importers:
|
||||
in_file.seek(0)
|
||||
for entry_data in importer(in_file):
|
||||
|
@ -106,7 +101,7 @@ def decimal_context(base=decimal.BasicContext):
|
|||
|
||||
def main(arglist=None, stdout=sys.stdout, stderr=sys.stderr):
|
||||
try:
|
||||
my_config = config.Configuration(arglist)
|
||||
my_config = config.Configuration(arglist, stdout, stderr)
|
||||
except errors.UserInputError as error:
|
||||
my_config.error("{}: {!r}".format(error.strerror, error.user_input))
|
||||
return 3
|
||||
|
|
|
@ -28,7 +28,10 @@ class Configuration:
|
|||
'unsigned_currency_format': '#,##0.### ¤¤',
|
||||
}
|
||||
|
||||
def __init__(self, arglist):
|
||||
def __init__(self, arglist, stdout, stderr):
|
||||
self.stdout = stdout
|
||||
self.stderr = stderr
|
||||
|
||||
argparser = self._build_argparser()
|
||||
self.error = argparser.error
|
||||
self.args = argparser.parse_args(arglist)
|
||||
|
@ -186,6 +189,14 @@ class Configuration:
|
|||
for secname in self.conffile}
|
||||
self.date_ranges[default_secname] = self._parse_date_range(default_secname)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _open_path(self, path, fallback_file, *args, **kwargs):
|
||||
if path is None:
|
||||
yield fallback_file
|
||||
else:
|
||||
with path.open(*args, **kwargs) as open_file:
|
||||
yield open_file
|
||||
|
||||
@contextlib.contextmanager
|
||||
def from_section(self, section_name):
|
||||
prev_section = self.args.use_config
|
||||
|
@ -226,6 +237,10 @@ class Configuration:
|
|||
section_config = self._get_section(section_name)
|
||||
return self._s_to_path(section_config['output_path'])
|
||||
|
||||
def open_output_file(self, section_name=None):
|
||||
path = self.get_output_path(section_name)
|
||||
return self._open_path(path, self.stdout, 'a')
|
||||
|
||||
def get_template(self, config_key, section_name=None, factory=template.Template):
|
||||
section_config = self._get_section(section_name)
|
||||
try:
|
||||
|
|
|
@ -14,12 +14,12 @@ from import2ledger import config, errors
|
|||
|
||||
from . import DATA_DIR
|
||||
|
||||
def config_from_file(path, arglist=[]):
|
||||
def config_from_file(path, arglist=[], stdout=None, stderr=None):
|
||||
path = pathlib.Path(path)
|
||||
if not path.is_absolute():
|
||||
path = DATA_DIR / path
|
||||
arglist = ['-C', path.as_posix(), *arglist, os.devnull]
|
||||
return config.Configuration(arglist)
|
||||
return config.Configuration(arglist, stdout, stderr)
|
||||
|
||||
def test_defaults():
|
||||
config = config_from_file('test_config.ini', ['--sign', 'GBP', '-O', 'out_arg'])
|
||||
|
|
Loading…
Reference in a new issue