diff --git a/import2ledger/__main__.py b/import2ledger/__main__.py index 14a8294..9b34130 100644 --- a/import2ledger/__main__.py +++ b/import2ledger/__main__.py @@ -1,5 +1,6 @@ import collections import contextlib +import decimal import logging import sys @@ -87,6 +88,22 @@ def setup_logger(logger, main_config, stream): handler.setFormatter(formatter) logger.addHandler(handler) +def decimal_context(base=decimal.BasicContext): + context = base.copy() + context.rounding = decimal.ROUND_HALF_EVEN + context.traps = { + decimal.Clamped: True, + decimal.DivisionByZero: True, + decimal.FloatOperation: True, + decimal.Inexact: False, + decimal.InvalidOperation: True, + decimal.Overflow: True, + decimal.Rounded: False, + decimal.Subnormal: True, + decimal.Underflow: True, + } + return context + def main(arglist=None, stdout=sys.stdout, stderr=sys.stderr): try: my_config = config.Configuration(arglist) @@ -94,14 +111,15 @@ def main(arglist=None, stdout=sys.stdout, stderr=sys.stderr): my_config.error("{}: {!r}".format(error.strerror, error.user_input)) return 3 setup_logger(logger, my_config, stderr) - importer = FileImporter(my_config, stdout) - failures = 0 - for input_path, error in importer.import_paths(my_config.args.input_paths): - if error is None: - logger.info("%s: imported", input_path) - else: - logger.warning("%s: failed to import: %s", input_path or error.path, error) - failures += 1 + with decimal.localcontext(decimal_context()): + importer = FileImporter(my_config, stdout) + failures = 0 + for input_path, error in importer.import_paths(my_config.args.input_paths): + if error is None: + logger.info("%s: imported", input_path) + else: + logger.warning("%s: failed to import: %s", input_path or error.path, error) + failures += 1 if failures == 0: return 0 else: diff --git a/tests/__init__.py b/tests/__init__.py index 4c5c929..5a8d707 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,6 +1,11 @@ +import decimal import pathlib import re +from import2ledger import __main__ as i2lmain + +decimal.setcontext(i2lmain.decimal_context()) + DATA_DIR = pathlib.Path(__file__).with_name('data') def normalize_whitespace(s):