main: Set up a decimal context that safely handles currency.

This commit is contained in:
Brett Smith 2017-12-31 10:06:57 -05:00
parent 85b665200c
commit efe5768941
2 changed files with 31 additions and 8 deletions

View file

@ -1,5 +1,6 @@
import collections import collections
import contextlib import contextlib
import decimal
import logging import logging
import sys import sys
@ -87,6 +88,22 @@ def setup_logger(logger, main_config, stream):
handler.setFormatter(formatter) handler.setFormatter(formatter)
logger.addHandler(handler) logger.addHandler(handler)
def decimal_context(base=decimal.BasicContext):
context = base.copy()
context.rounding = decimal.ROUND_HALF_EVEN
context.traps = {
decimal.Clamped: True,
decimal.DivisionByZero: True,
decimal.FloatOperation: True,
decimal.Inexact: False,
decimal.InvalidOperation: True,
decimal.Overflow: True,
decimal.Rounded: False,
decimal.Subnormal: True,
decimal.Underflow: True,
}
return context
def main(arglist=None, stdout=sys.stdout, stderr=sys.stderr): def main(arglist=None, stdout=sys.stdout, stderr=sys.stderr):
try: try:
my_config = config.Configuration(arglist) my_config = config.Configuration(arglist)
@ -94,6 +111,7 @@ def main(arglist=None, stdout=sys.stdout, stderr=sys.stderr):
my_config.error("{}: {!r}".format(error.strerror, error.user_input)) my_config.error("{}: {!r}".format(error.strerror, error.user_input))
return 3 return 3
setup_logger(logger, my_config, stderr) setup_logger(logger, my_config, stderr)
with decimal.localcontext(decimal_context()):
importer = FileImporter(my_config, stdout) importer = FileImporter(my_config, stdout)
failures = 0 failures = 0
for input_path, error in importer.import_paths(my_config.args.input_paths): for input_path, error in importer.import_paths(my_config.args.input_paths):

View file

@ -1,6 +1,11 @@
import decimal
import pathlib import pathlib
import re import re
from import2ledger import __main__ as i2lmain
decimal.setcontext(i2lmain.decimal_context())
DATA_DIR = pathlib.Path(__file__).with_name('data') DATA_DIR = pathlib.Path(__file__).with_name('data')
def normalize_whitespace(s): def normalize_whitespace(s):