config: Add open_output_file method.
Make this functionality accessible to hooks.
This commit is contained in:
		
							parent
							
								
									efe5768941
								
							
						
					
					
						commit
						d2f8772e08
					
				
					 3 changed files with 20 additions and 10 deletions
				
			
		| 
						 | 
				
			
			@ -43,12 +43,7 @@ class FileImporter:
 | 
			
		|||
            'source_path': in_path.as_posix(),
 | 
			
		||||
            'source_stem': in_path.stem,
 | 
			
		||||
        }
 | 
			
		||||
        with contextlib.ExitStack() as exit_stack:
 | 
			
		||||
            output_path = self.config.get_output_path()
 | 
			
		||||
            if output_path is None:
 | 
			
		||||
                out_file = self.stdout
 | 
			
		||||
            else:
 | 
			
		||||
                out_file = exit_stack.enter_context(output_path.open('a'))
 | 
			
		||||
        with self.config.open_output_file() as out_file:
 | 
			
		||||
            for importer, template in importers:
 | 
			
		||||
                in_file.seek(0)
 | 
			
		||||
                for entry_data in importer(in_file):
 | 
			
		||||
| 
						 | 
				
			
			@ -106,7 +101,7 @@ def decimal_context(base=decimal.BasicContext):
 | 
			
		|||
 | 
			
		||||
def main(arglist=None, stdout=sys.stdout, stderr=sys.stderr):
 | 
			
		||||
    try:
 | 
			
		||||
        my_config = config.Configuration(arglist)
 | 
			
		||||
        my_config = config.Configuration(arglist, stdout, stderr)
 | 
			
		||||
    except errors.UserInputError as error:
 | 
			
		||||
        my_config.error("{}: {!r}".format(error.strerror, error.user_input))
 | 
			
		||||
        return 3
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -28,7 +28,10 @@ class Configuration:
 | 
			
		|||
        'unsigned_currency_format': '#,##0.### ¤¤',
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    def __init__(self, arglist):
 | 
			
		||||
    def __init__(self, arglist, stdout, stderr):
 | 
			
		||||
        self.stdout = stdout
 | 
			
		||||
        self.stderr = stderr
 | 
			
		||||
 | 
			
		||||
        argparser = self._build_argparser()
 | 
			
		||||
        self.error = argparser.error
 | 
			
		||||
        self.args = argparser.parse_args(arglist)
 | 
			
		||||
| 
						 | 
				
			
			@ -186,6 +189,14 @@ class Configuration:
 | 
			
		|||
                            for secname in self.conffile}
 | 
			
		||||
        self.date_ranges[default_secname] = self._parse_date_range(default_secname)
 | 
			
		||||
 | 
			
		||||
    @contextlib.contextmanager
 | 
			
		||||
    def _open_path(self, path, fallback_file, *args, **kwargs):
 | 
			
		||||
        if path is None:
 | 
			
		||||
            yield fallback_file
 | 
			
		||||
        else:
 | 
			
		||||
            with path.open(*args, **kwargs) as open_file:
 | 
			
		||||
                yield open_file
 | 
			
		||||
 | 
			
		||||
    @contextlib.contextmanager
 | 
			
		||||
    def from_section(self, section_name):
 | 
			
		||||
        prev_section = self.args.use_config
 | 
			
		||||
| 
						 | 
				
			
			@ -226,6 +237,10 @@ class Configuration:
 | 
			
		|||
        section_config = self._get_section(section_name)
 | 
			
		||||
        return self._s_to_path(section_config['output_path'])
 | 
			
		||||
 | 
			
		||||
    def open_output_file(self, section_name=None):
 | 
			
		||||
        path = self.get_output_path(section_name)
 | 
			
		||||
        return self._open_path(path, self.stdout, 'a')
 | 
			
		||||
 | 
			
		||||
    def get_template(self, config_key, section_name=None, factory=template.Template):
 | 
			
		||||
        section_config = self._get_section(section_name)
 | 
			
		||||
        try:
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -14,12 +14,12 @@ from import2ledger import config, errors
 | 
			
		|||
 | 
			
		||||
from . import DATA_DIR
 | 
			
		||||
 | 
			
		||||
def config_from_file(path, arglist=[]):
 | 
			
		||||
def config_from_file(path, arglist=[], stdout=None, stderr=None):
 | 
			
		||||
    path = pathlib.Path(path)
 | 
			
		||||
    if not path.is_absolute():
 | 
			
		||||
        path = DATA_DIR / path
 | 
			
		||||
    arglist = ['-C', path.as_posix(), *arglist, os.devnull]
 | 
			
		||||
    return config.Configuration(arglist)
 | 
			
		||||
    return config.Configuration(arglist, stdout, stderr)
 | 
			
		||||
 | 
			
		||||
def test_defaults():
 | 
			
		||||
    config = config_from_file('test_config.ini', ['--sign', 'GBP', '-O', 'out_arg'])
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
	Add table
		
		Reference in a new issue