main: Set up a decimal context that safely handles currency.
This commit is contained in:
parent
85b665200c
commit
efe5768941
2 changed files with 31 additions and 8 deletions
|
@ -1,5 +1,6 @@
|
||||||
import collections
|
import collections
|
||||||
import contextlib
|
import contextlib
|
||||||
|
import decimal
|
||||||
import logging
|
import logging
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
@ -87,6 +88,22 @@ def setup_logger(logger, main_config, stream):
|
||||||
handler.setFormatter(formatter)
|
handler.setFormatter(formatter)
|
||||||
logger.addHandler(handler)
|
logger.addHandler(handler)
|
||||||
|
|
||||||
|
def decimal_context(base=decimal.BasicContext):
|
||||||
|
context = base.copy()
|
||||||
|
context.rounding = decimal.ROUND_HALF_EVEN
|
||||||
|
context.traps = {
|
||||||
|
decimal.Clamped: True,
|
||||||
|
decimal.DivisionByZero: True,
|
||||||
|
decimal.FloatOperation: True,
|
||||||
|
decimal.Inexact: False,
|
||||||
|
decimal.InvalidOperation: True,
|
||||||
|
decimal.Overflow: True,
|
||||||
|
decimal.Rounded: False,
|
||||||
|
decimal.Subnormal: True,
|
||||||
|
decimal.Underflow: True,
|
||||||
|
}
|
||||||
|
return context
|
||||||
|
|
||||||
def main(arglist=None, stdout=sys.stdout, stderr=sys.stderr):
|
def main(arglist=None, stdout=sys.stdout, stderr=sys.stderr):
|
||||||
try:
|
try:
|
||||||
my_config = config.Configuration(arglist)
|
my_config = config.Configuration(arglist)
|
||||||
|
@ -94,14 +111,15 @@ def main(arglist=None, stdout=sys.stdout, stderr=sys.stderr):
|
||||||
my_config.error("{}: {!r}".format(error.strerror, error.user_input))
|
my_config.error("{}: {!r}".format(error.strerror, error.user_input))
|
||||||
return 3
|
return 3
|
||||||
setup_logger(logger, my_config, stderr)
|
setup_logger(logger, my_config, stderr)
|
||||||
importer = FileImporter(my_config, stdout)
|
with decimal.localcontext(decimal_context()):
|
||||||
failures = 0
|
importer = FileImporter(my_config, stdout)
|
||||||
for input_path, error in importer.import_paths(my_config.args.input_paths):
|
failures = 0
|
||||||
if error is None:
|
for input_path, error in importer.import_paths(my_config.args.input_paths):
|
||||||
logger.info("%s: imported", input_path)
|
if error is None:
|
||||||
else:
|
logger.info("%s: imported", input_path)
|
||||||
logger.warning("%s: failed to import: %s", input_path or error.path, error)
|
else:
|
||||||
failures += 1
|
logger.warning("%s: failed to import: %s", input_path or error.path, error)
|
||||||
|
failures += 1
|
||||||
if failures == 0:
|
if failures == 0:
|
||||||
return 0
|
return 0
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -1,6 +1,11 @@
|
||||||
|
import decimal
|
||||||
import pathlib
|
import pathlib
|
||||||
import re
|
import re
|
||||||
|
|
||||||
|
from import2ledger import __main__ as i2lmain
|
||||||
|
|
||||||
|
decimal.setcontext(i2lmain.decimal_context())
|
||||||
|
|
||||||
DATA_DIR = pathlib.Path(__file__).with_name('data')
|
DATA_DIR = pathlib.Path(__file__).with_name('data')
|
||||||
|
|
||||||
def normalize_whitespace(s):
|
def normalize_whitespace(s):
|
||||||
|
|
Loading…
Reference in a new issue