reconcile: Add type checking information to new prototype reconcilers.

This commit is contained in:
Ben Sturmfels 2022-02-04 15:03:23 +11:00
parent 43548a1ac9
commit ed0bc469ce
Signed by: bsturmfels
GPG key ID: 023C05E2C9C068F0
2 changed files with 17 additions and 13 deletions

View file

@ -20,19 +20,21 @@ import datetime
import io import io
import tempfile import tempfile
import textwrap import textwrap
import typing
from typing import List
import os import os
from beancount import loader from beancount import loader
from beancount.query.query import run_query from beancount.query.query import run_query
def end_of_month(date): def end_of_month(date: datetime.date) -> datetime.date:
"""Given a date, return the last day of the month.""" """Given a date, return the last day of the month."""
# Using 'day' replaces, rather than adds. # Using 'day' replaces, rather than adds.
return date + relativedelta(day=31) return date + relativedelta(day=31)
def format_record_for_grep(row, homedir): def format_record_for_grep(row: typing.List, homedir: str) -> typing.List:
"""Return a line in a grep-style. """Return a line in a grep-style.
This is so the line can be fed into Emacs grep-mode for quickly jumping to This is so the line can be fed into Emacs grep-mode for quickly jumping to
@ -42,7 +44,7 @@ def format_record_for_grep(row, homedir):
return [f'{file}:{row[1]}:'] + row[2:] return [f'{file}:{row[1]}:'] + row[2:]
def max_column_widths(rows): def max_column_widths(rows: List) -> List[int]:
"""Return the max width for each column in a table of data.""" """Return the max width for each column in a table of data."""
if not rows: if not rows:
return [] return []
@ -55,7 +57,7 @@ def max_column_widths(rows):
return maxes return maxes
def tabulate(rows, headers=None): def tabulate(rows: List, headers: List=None) -> str:
"""Format a table of data as a string. """Format a table of data as a string.
Implemented here to avoid adding dependency on "tabulate" package. Implemented here to avoid adding dependency on "tabulate" package.
@ -101,8 +103,9 @@ else:
if not (args.cur_end_date and args.prev_end_date): if not (args.cur_end_date and args.prev_end_date):
parser.error(' --prev-end-date and --cur-end-date must be used together') parser.error(' --prev-end-date and --cur-end-date must be used together')
preDate = args.prev_end_date preDate = args.prev_end_date
lastDateInPeriod = args.cur_end_date lastDateInPeriod = args.cur_end_date.isoformat()
month = lastDateInPeriod.strftime('%Y-%m') month = args.cur_end_date.strftime('%Y-%m')
grep_output_file: typing.IO
if args.grep_output_filename: if args.grep_output_filename:
grep_output_file = open(args.grep_output_filename, 'w') grep_output_file = open(args.grep_output_filename, 'w')
else: else:
@ -168,7 +171,7 @@ for desc, query in QUERIES.items():
if not rrows: if not rrows:
print(f'{desc:<55} {"N/A":>11}') print(f'{desc:<55} {"N/A":>11}')
elif desc.startswith('04'): elif desc.startswith('04'):
homedir = os.getenv('HOME') homedir = os.getenv('HOME', '')
print(f'{desc}\n See {grep_output_file.name}') print(f'{desc}\n See {grep_output_file.name}')
grep_rows = [format_record_for_grep(row, homedir) for row in rrows] grep_rows = [format_record_for_grep(row, homedir) for row in rrows]
print(tabulate(grep_rows), file=grep_output_file) print(tabulate(grep_rows), file=grep_output_file)

View file

@ -10,14 +10,15 @@ import argparse
import csv import csv
import datetime import datetime
import decimal import decimal
from typing import Dict, List, Tuple
from beancount import loader from beancount import loader
from beancount.query.query import run_query from beancount.query.query import run_query
from thefuzz import fuzz from thefuzz import fuzz # type: ignore
# NOTE: Statement doesn't seem to give us a running balance or a final total. # NOTE: Statement doesn't seem to give us a running balance or a final total.
def standardize_amex_record(row): def standardize_amex_record(row: Dict) -> Dict:
return { return {
'date': datetime.datetime.strptime(row['Date'], '%m/%d/%Y').date(), 'date': datetime.datetime.strptime(row['Date'], '%m/%d/%Y').date(),
'amount': -1 * decimal.Decimal(row['Amount']), 'amount': -1 * decimal.Decimal(row['Amount']),
@ -25,7 +26,7 @@ def standardize_amex_record(row):
} }
def standardize_beancount_record(row): def standardize_beancount_record(row) -> Dict: # type: ignore[no-untyped-def]
return { return {
'date': row.date, 'date': row.date,
'amount': row.number_cost_position, 'amount': row.number_cost_position,
@ -33,15 +34,15 @@ def standardize_beancount_record(row):
} }
def format_record(record): def format_record(record: Dict) -> str:
return f"{record['date'].isoformat()}: {record['amount']:>8} {record['payee'][:20]:<20}" return f"{record['date'].isoformat()}: {record['amount']:>8} {record['payee'][:20]:<20}"
def sort_records(records): def sort_records(records: List) -> List:
return sorted(records, key=lambda x: (x['date'], x['amount'])) return sorted(records, key=lambda x: (x['date'], x['amount']))
def records_match(r1, r2): def records_match(r1: Dict, r2: Dict) -> Tuple[bool, str]:
"""Do these records represent the same transaction?""" """Do these records represent the same transaction?"""
date_matches = r1['date'] >= r2['date'] - datetime.timedelta(days=1) and r1['date'] <= r2['date'] + datetime.timedelta(days=1) date_matches = r1['date'] >= r2['date'] - datetime.timedelta(days=1) and r1['date'] <= r2['date'] + datetime.timedelta(days=1)
amount_matches = r1['amount'] == r2['amount'] amount_matches = r1['amount'] == r2['amount']