"""test_reports_ledger.py - Unit tests for general ledger report"""
# Copyright © 2020  Brett Smith
# License: AGPLv3-or-later WITH Beancount-Plugin-Additional-Permission-1.0
#
# Full copyright and licensing details can be found at toplevel file
# LICENSE.txt in the repository.

import collections
import contextlib
import copy
import datetime
import io
import itertools
import re

import pytest

from . import testutil

import odf.table
import odf.text

from beancount.core import data as bc_data
from beancount import loader as bc_loader
from conservancy_beancount import data
from conservancy_beancount.reports import core
from conservancy_beancount.reports import ledger

# testutil.clean_account_meta is a generator function; wrapping it lets the
# tests use it as a ``with`` block around code that loads account metadata.
clean_account_meta = contextlib.contextmanager(testutil.clean_account_meta)

# Short alias used by the plan_sheets tests.
Acct = data.Account

# The test books are parsed once at import; the ledger_entries fixture hands
# out deep copies of the entries.
_ledger_load = bc_loader.load_file(testutil.test_path('books/ledger.beancount'))

# Sheet names (spaces mapped to colons) expected, in document order, from a
# default report of the test books.
DEFAULT_REPORT_SHEETS = [
    'Income',
    'Expenses:Payroll',
    'Expenses',
    'Equity',
    'Assets:Receivable',
    'Liabilities:Payable',
    'Assets:PayPal',
    'Assets',
    'Liabilities',
]
# Sheet names expected, in document order, from a fund_ledger project report.
PROJECT_REPORT_SHEETS = [
    'Balance',
    'Income',
    'Expenses',
    'Equity',
    'Assets:Receivable',
    'Assets:Prepaid',
    'Liabilities:UnearnedIncome',
    'Liabilities:Payable',
]

# Matches the log message about an account too large for one sheet:
# group 1 is the account, group 2 its row count, group 3 the size limit.
OVERSIZE_RE = re.compile(r'^([A-Za-z0-9:]+) has ([0-9,]+) rows, over size ([0-9,]+)$')

START_DATE = datetime.date(2018, 3, 1)
MID_DATE = datetime.date(2019, 3, 1)
STOP_DATE = datetime.date(2020, 3, 1)

# Keyword arguments for each report variant under test: the general ledger,
# plus one transaction report per ReportType flag overlapping ALL_TRANSACTIONS.
REPORT_KWARGS = [{'report_class': ledger.LedgerODS}]
REPORT_KWARGS.extend(
    {'report_class': ledger.TransactionODS, 'txn_filter': flag}
    for flag in ledger.ReportType
    if flag & ledger.ReportType.ALL_TRANSACTIONS
)

@pytest.fixture
def ledger_entries():
    """Return a fresh deep copy of the test books' entries.

    Each test gets its own copy, so a test that mutates entries cannot
    affect any other test.
    """
    entries = _ledger_load[0]
    return copy.deepcopy(entries)

def iter_accounts(entries):
    """Yield the account of every Open directive in *entries*, in order."""
    yield from (
        entry.account for entry in entries if isinstance(entry, bc_data.Open)
    )

class NotFound(Exception):
    """An account's section could not be located in a report."""

class NoSheet(NotFound):
    """No sheet in the report covers the account."""

class NoHeader(NotFound):
    """The covering sheet has no header row for the account."""

class ExpectedPostings(core.RelatedPostings):
    """Postings for one account, with helpers to check them against a report.

    The ``check_*`` methods locate this account's section in an ODS document
    and assert, row by row, that the report shows these postings.
    """
    @classmethod
    def find_section(cls, ods, account):
        """Return an iterator over the rows that follow *account*'s header.

        The first sheet, in document order, whose name (with spaces mapped to
        colons) satisfies ``account.is_under(sheet_name)`` is searched for a
        header row.  Raises NoSheet if no sheet qualifies, or NoHeader if the
        chosen sheet has no header row for *account*.
        """
        for sheet in ods.getElementsByType(odf.table.Table):
            sheet_account = sheet.getAttribute('name').replace(' ', ':')
            if sheet_account and account.is_under(sheet_account):
                break
        else:
            raise NoSheet(account)
        rows = iter(sheet.getElementsByType(odf.table.TableRow))
        # A header row has at least three cells, an empty first cell, and the
        # account name in the second cell.
        for row in rows:
            cells = row.childNodes
            if (len(cells) >= 3
                and cells[1].text == account
                and not cells[0].text):
                break
        else:
            raise NoHeader(account)
        # The iterator is left just past the header; callers read the
        # section's rows from it with next().
        return rows

    @classmethod
    def check_not_in_report(cls, ods, *accounts):
        """Assert that find_section raises NotFound for each of *accounts*.

        Note that nothing is checked when no accounts are passed.
        """
        for account in accounts:
            with pytest.raises(NotFound):
                cls.find_section(ods, data.Account(account))

    @classmethod
    def check_in_report(cls, ods, account,
                        start_date=START_DATE, end_date=STOP_DATE, txn_filter=None):
        """Assert *account* has a section showing no postings in the period.

        Builds a single zero-amount posting for *account* and runs the normal
        check with it: check_report when *txn_filter* is None, otherwise
        check_txn_report.
        """
        # Dating the dummy posting one day after end_date keeps it out of
        # both the opening balance and the period in slice_date_range.
        date = end_date + datetime.timedelta(days=1)
        txn = testutil.Transaction(date=date, postings=[
            (account, 0),
        ])
        related = cls(data.Posting.from_txn(txn))
        if txn_filter is None:
            related.check_report(ods, start_date, end_date)
        else:
            related.check_txn_report(ods, txn_filter, start_date, end_date)

    def slice_date_range(self, start_date, end_date):
        """Split these postings around the period [start_date, end_date).

        Returns ``(opening_balance, period_postings)``: the at-cost balance of
        the postings before *start_date*, and the slice of postings from the
        first one dated on/after *start_date* up to (not including) the first
        one dated on/after *end_date*.  Assumes the postings are in date order.

        NOTE(review): raises UnboundLocalError when there are no postings,
        because ``start_index`` is never bound.
        """
        # Both loops share this iterator, so the end search resumes right
        # after the posting where the start search stopped.
        postings = enumerate(self)
        for start_index, post in postings:
            if start_date <= post.meta.date:
                break
        else:
            # Every posting precedes start_date: the period starts past the end.
            start_index += 1
        if end_date <= post.meta.date:
            end_index = start_index
        else:
            for end_index, post in postings:
                if end_date <= post.meta.date:
                    break
            else:
                # No posting is on/after end_date: the period runs to the end.
                end_index = None
        return (self[:start_index].balance_at_cost(),
                self[start_index:end_index])

    def check_report(self, ods, start_date, end_date, expect_totals=True):
        """Assert this account's ledger-report section matches these postings.

        Expects, in order:
        - an opening-balance row, only for Assets/Liabilities accounts and
          only with *expect_totals*;
        - one row per posting in the period;
        - a closing-balance row, only with *expect_totals*.
        """
        account = self[0].account
        norm_func = core.normalize_amount_func(account)
        open_bal, expect_posts = self.slice_date_range(start_date, end_date)
        open_bal = norm_func(open_bal)
        closing_bal = norm_func(expect_posts.balance_at_cost())
        rows = self.find_section(ods, account)
        if expect_totals and account.is_under('Assets', 'Liabilities'):
            opening_row = testutil.ODSCell.from_row(next(rows))
            assert opening_row[0].value == start_date
            assert opening_row[4].text == open_bal.format(None, empty='0', sep='\0')
            # Balance-sheet accounts carry the opening balance into the closing one.
            closing_bal += open_bal
        # Posting rows: date, entity, narration, then either a blank cell and
        # the amount (no cost) or the units and the at-cost amount.
        for expected in expect_posts:
            cells = iter(testutil.ODSCell.from_row(next(rows)))
            assert next(cells).value == expected.meta.date
            assert next(cells).text == (expected.meta.get('entity') or '')
            assert next(cells).text == (expected.meta.txn.narration or '')
            if expected.cost is None:
                assert not next(cells).text
                assert next(cells).value == norm_func(expected.units.number)
            else:
                assert next(cells).value == norm_func(expected.units.number)
                assert next(cells).value == norm_func(expected.at_cost().number)
        if expect_totals:
            closing_row = testutil.ODSCell.from_row(next(rows))
            assert closing_row[0].value == end_date
            # A zero closing balance is expected to render as '$0.00' when
            # the period had postings, and as '0' otherwise.
            empty = '$0.00' if expect_posts else '0'
            assert closing_row[4].text == closing_bal.format(None, empty=empty, sep='\0')

    def _post_data_from_row(self, row):
        """Return ``(account, entity, number, currency)`` from a posting row.

        When column 4 has text, its value is used and the currency is parsed
        from that cell's style name.  Otherwise column 5's value is used and
        the currency is taken to be USD.
        """
        if row[4].text:
            number = row[4].value
            match = re.search(r'([A-Z]{3})\d*Cell', row[4].getAttribute('stylename') or '')
            assert match
            currency = match.group(1)
        else:
            number = row[5].value
            currency = 'USD'
        return (row[2].text, row[3].text, number, currency)

    def _post_data_from_post(self, post, norm_func):
        """Return the _post_data_from_row tuple expected for *post*.

        The number is passed through *norm_func*.
        """
        return (
            post.account,
            post.meta.get('entity') or '',
            norm_func(post.units.number),
            post.units.currency,
        )

    def check_txn_report(self, ods, txn_filter, start_date, end_date, expect_totals=True):
        """Assert this account's transaction-report section matches.

        Which rows are expected:
        - Only postings whose ReportType flag overlaps *txn_filter* are
          expected.  Consecutive postings from the same transaction are
          checked once.
        - Each transaction is a first row with the date, narration and one
          posting, followed by one row per remaining posting.  Posting rows
          are compared as a set, so their order is not checked.
        - An opening-balance row is expected only for Assets/Liabilities
          accounts when *txn_filter* is ALL_TRANSACTIONS and *expect_totals*
          is true.
        - A closing row is expected whenever *expect_totals* is true.
        """
        account = self[0].account
        norm_func = core.normalize_amount_func(account)
        open_bal, expect_posts = self.slice_date_range(start_date, end_date)
        open_bal = norm_func(open_bal)
        # Running balance of this account's amounts as shown in the report.
        period_bal = core.MutableBalance()
        rows = self.find_section(ods, account)
        if (expect_totals
            and txn_filter == ledger.ReportType.ALL_TRANSACTIONS
            and account.is_under('Assets', 'Liabilities')):
            opening_row = testutil.ODSCell.from_row(next(rows))
            assert opening_row[0].value == start_date
            assert opening_row[5].text == open_bal.format(None, empty='0', sep='\0')
            period_bal += open_bal
        last_txn = None
        for post in expect_posts:
            txn = post.meta.txn
            post_flag = ledger.ReportType.post_flag(post)
            # Skip repeats of the transaction just checked, and postings the
            # filter excludes.
            if txn is last_txn or (not txn_filter & post_flag):
                continue
            last_txn = txn
            row1 = testutil.ODSCell.from_row(next(rows))
            assert row1[0].value == txn.date
            assert row1[1].text == (txn.narration or '')
            expected = {self._post_data_from_post(post, norm_func)
                        for post in txn.postings}
            actual = {self._post_data_from_row(testutil.ODSCell.from_row(row))
                      for row in itertools.islice(rows, len(txn.postings) - 1)}
            actual.add(self._post_data_from_row(row1))
            assert actual == expected
            for post_acct, _, number, currency in expected:
                if post_acct == account:
                    period_bal += testutil.Amount(number, currency)
        if expect_totals:
            closing_row = testutil.ODSCell.from_row(next(rows))
            assert closing_row[0].value == end_date
            empty = '$0.00' if period_bal else '0'
            assert closing_row[5].text == period_bal.format(None, empty=empty, sep='\0')


def get_sheet_names(ods):
    """Return every sheet name in *ods*, in order, with spaces turned into colons."""
    names = []
    for table in ods.getElementsByType(odf.table.Table):
        raw_name = table.getAttribute('name')
        names.append(raw_name.replace(' ', ':'))
    return names

def check_oversize_logs(caplog, accounts, sheet_size):
    """Assert oversize messages were logged for exactly the right accounts.

    *accounts* maps account names to row counts.  Every captured record that
    matches OVERSIZE_RE must report *sheet_size* as the limit.  The accounts
    named by those records, with their row counts, must be exactly the
    entries of *accounts* whose count exceeds *sheet_size*.
    """
    def as_int(digits):
        return int(digits.replace(',', ''))

    logged = {}
    for record in caplog.records:
        found = OVERSIZE_RE.match(record.message)
        if found is None:
            continue
        name, row_count, limit = found.groups()
        assert as_int(limit) == sheet_size
        logged[name] = as_int(row_count)
    want = {
        name: size for name, size in accounts.items() if size > sheet_size
    }
    assert logged == want

def test_plan_sheets_no_change():
    """The requested sheets come back unchanged when every sheet fits."""
    row_counts = {Acct('Assets:Cash'): 10, Acct('Income:Donations'): 20}
    sheets = ['Assets', 'Income']
    assert ledger.LedgerODS.plan_sheets(row_counts, list(sheets), 100) == sheets

@pytest.mark.parametrize('have', [
    {},
    {Acct('Income:Other'): 10},
    {Acct('Assets:Checking'): 20, Acct('Expenses:Other'): 15},
])
def test_plan_sheets_includes_accounts_without_transactions(have):
    """Requested sheets are kept even when no account under them has rows."""
    requested = ['Assets', 'Income', 'Expenses']
    planned = ledger.LedgerODS.plan_sheets(have, requested[:], 100)
    assert planned == requested

def test_plan_sheets_single_split():
    """Each oversized sheet gets one account split out ahead of its parent."""
    row_counts = {
        Acct('Assets:Cash'): 60,
        Acct('Assets:Checking'): 80,
        Acct('Income:Donations'): 50,
        Acct('Expenses:Travel'): 90,
        Acct('Expenses:FilingFees'): 25,
    }
    planned = ledger.LedgerODS.plan_sheets(
        row_counts, ['Assets', 'Income', 'Expenses'], 100,
    )
    expected = [
        'Assets:Checking', 'Assets', 'Income', 'Expenses:Travel', 'Expenses',
    ]
    assert planned == expected

def test_plan_sheets_split_subtree():
    """An intermediate subtree (Assets:Bank1) can be split out as one sheet."""
    row_counts = {
        Acct('Assets:Bank1:Checking'): 80,
        Acct('Assets:Bank1:Savings'): 10,
        Acct('Assets:Cash:USD'): 20,
        Acct('Assets:Cash:EUR'): 15,
    }
    planned = ledger.LedgerODS.plan_sheets(row_counts, ['Assets'], 100)
    assert planned == ['Assets:Bank1', 'Assets']

def test_plan_sheets_ambiguous_split():
    """The plan names the exact account split out, not its parent subtree."""
    row_counts = {
        Acct('Assets:Bank1:Checking'): 80,
        Acct('Assets:Bank1:Savings'): 40,
        Acct('Assets:Receivable:Accounts'): 40,
        Acct('Assets:Cash'): 10,
    }
    planned = ledger.LedgerODS.plan_sheets(row_counts, ['Assets'], 100)
    # Checking (80) and Savings (40) cannot share a 100-row sheet, so the
    # result must say Assets:Bank1:Checking rather than Assets:Bank1.
    assert planned == ['Assets:Bank1:Checking', 'Assets']

def test_plan_sheets_oversize(caplog):
    """An account too big for any sheet gets its own sheet and is logged."""
    row_counts = {Acct('Assets:Checking'): 150, Acct('Assets:Cash'): 50}
    planned = ledger.LedgerODS.plan_sheets(row_counts, ['Assets'], 100)
    assert planned == ['Assets:Checking', 'Assets']
    check_oversize_logs(caplog, row_counts, 100)

def test_plan_sheets_all_oversize(caplog):
    """When every account is oversized, each gets its own sheet, sorted by name."""
    row_counts = {Acct('Assets:Checking'): 150, Acct('Assets:Cash'): 150}
    planned = ledger.LedgerODS.plan_sheets(row_counts, ['Assets'], 100)
    assert planned == ['Assets:Cash', 'Assets:Checking']
    check_oversize_logs(caplog, row_counts, 100)

def test_plan_sheets_full_split_required(caplog):
    """Leaf accounts are split out as needed, without any oversize logs."""
    row_counts = {
        Acct('Assets:Bank:Savings'): 98,
        Acct('Assets:Bank:Checking'): 96,
        Acct('Assets:Bank:Investment'): 94,
    }
    planned = ledger.LedgerODS.plan_sheets(row_counts, ['Assets'], 100)
    expected = ['Assets:Bank:Checking', 'Assets:Bank:Savings', 'Assets']
    assert planned == expected
    assert len(caplog.records) == 0

def build_report(ledger_entries, start_date, stop_date, *args,
                 report_class=ledger.LedgerODS, **kwargs):
    """Build a *report_class* report over *ledger_entries* and write it.

    Extra positional and keyword arguments go to the report constructor.
    Account open/close metadata is loaded inside ``clean_account_meta()``.
    Returns ``(postings, report)``, where *postings* is the list of every
    posting that was written.
    """
    all_postings = list(data.Posting.from_entries(iter(ledger_entries)))
    with clean_account_meta():
        data.Account.load_openings_and_closings(iter(ledger_entries))
        report = report_class(start_date, stop_date, *args, **kwargs)
        report.write(iter(all_postings))
    return all_postings, report

@pytest.mark.parametrize('report_kwargs', iter(REPORT_KWARGS))
@pytest.mark.parametrize('start_date,stop_date', [
    (START_DATE, STOP_DATE),
    (START_DATE, MID_DATE),
    (MID_DATE, STOP_DATE),
    (START_DATE.replace(month=6), START_DATE.replace(month=12)),
    (STOP_DATE, STOP_DATE.replace(month=12)),
])
def test_date_range_report(ledger_entries, start_date, stop_date, report_kwargs):
    """Every opened account is reported correctly across several date ranges."""
    txn_filter = report_kwargs.get('txn_filter')
    postings, report = build_report(ledger_entries, start_date, stop_date, **report_kwargs)
    by_account = dict(ExpectedPostings.group_by_account(postings))
    for account in iter_accounts(ledger_entries):
        related = by_account.get(account)
        if related is None:
            ExpectedPostings.check_in_report(
                report.document, account, start_date, stop_date, txn_filter,
            )
        elif txn_filter is None:
            related.check_report(report.document, start_date, stop_date)
        else:
            related.check_txn_report(
                report.document, txn_filter, start_date, stop_date,
            )

@pytest.mark.parametrize('report_kwargs', iter(REPORT_KWARGS))
@pytest.mark.parametrize('tot_accts', [
    (),
    ('Assets', 'Liabilities'),
    ('Income', 'Expenses'),
    ('Assets', 'Liabilities', 'Income', 'Expenses'),
])
def test_report_filter_totals(ledger_entries, tot_accts, report_kwargs):
    """Totals appear exactly for accounts under *tot_accts*.

    Those accounts are reported even without recent postings.  Other
    accounts with no postings on/after START_DATE are omitted entirely.
    """
    txn_filter = report_kwargs.get('txn_filter')
    postings, report = build_report(
        ledger_entries, START_DATE, STOP_DATE,
        totals_with_entries=tot_accts,
        totals_without_entries=tot_accts,
        **report_kwargs
    )
    by_account = dict(ExpectedPostings.group_by_account(postings))
    for account in iter_accounts(ledger_entries):
        expect_totals = account.startswith(tot_accts)
        related = by_account.get(account)
        has_recent_postings = (
            related is not None and related[-1].meta.date >= START_DATE
        )
        if not has_recent_postings:
            if expect_totals:
                ExpectedPostings.check_in_report(
                    report.document, account, START_DATE, STOP_DATE, txn_filter,
                )
            else:
                ExpectedPostings.check_not_in_report(report.document, account)
        elif txn_filter is None:
            related.check_report(
                report.document, START_DATE, STOP_DATE, expect_totals=expect_totals,
            )
        else:
            related.check_txn_report(
                report.document, txn_filter, START_DATE, STOP_DATE,
                expect_totals=expect_totals,
            )

@pytest.mark.parametrize('report_kwargs', iter(REPORT_KWARGS))
@pytest.mark.parametrize('accounts', [
    ('Income', 'Expenses'),
    ('Assets:Receivable', 'Liabilities:Payable'),
])
def test_account_names_report(ledger_entries, accounts, report_kwargs):
    """Only accounts under the requested names appear in the report."""
    txn_filter = report_kwargs.get('txn_filter')
    postings, report = build_report(ledger_entries, START_DATE, STOP_DATE,
                                    accounts, **report_kwargs)
    by_account = dict(ExpectedPostings.group_by_account(postings))
    for account in iter_accounts(ledger_entries):
        if not account.startswith(accounts):
            ExpectedPostings.check_not_in_report(report.document, account)
            continue
        if account == 'Expenses:Payroll':
            # Reportable, but the test books have no postings for it.
            ExpectedPostings.check_in_report(
                report.document, account, START_DATE, STOP_DATE, txn_filter,
            )
            continue
        related = by_account[account]
        if txn_filter is None:
            related.check_report(report.document, START_DATE, STOP_DATE)
        else:
            related.check_txn_report(
                report.document, txn_filter, START_DATE, STOP_DATE,
            )

def run_main(arglist, config=None):
    """Run ``ledger.main`` in-process and capture its results.

    ``--output-file=-`` is prepended to the arguments so the report is
    written to the returned ``output`` buffer.  *arglist* itself is left
    unmodified; it used to be mutated in place.  When *config* is None, a
    test config pointing at the test books is used.

    Returns ``(retcode, output, errors)``: ``output`` is a BytesIO rewound to
    the start, and ``errors`` is a StringIO of the error stream.
    """
    if config is None:
        config = testutil.TestConfig(
            books_path=testutil.test_path('books/ledger.beancount'),
            rt_client=testutil.RTClient(),
        )
    argv = ['--output-file=-', *arglist]
    output = io.BytesIO()
    errors = io.StringIO()
    with clean_account_meta():
        retcode = ledger.main(argv, output, errors, config)
    output.seek(0)
    return retcode, output, errors

def test_main(ledger_entries):
    """The default CLI report has the expected sheets and matches every account.

    Accounts with postings must match them.  Accounts without postings must
    still have an empty section.
    """
    retcode, output, errors = run_main([
        '-b', START_DATE.isoformat(),
        '-e', STOP_DATE.isoformat(),
    ])
    assert not errors.getvalue()
    assert retcode == 0
    ods = odf.opendocument.load(output)
    assert get_sheet_names(ods) == DEFAULT_REPORT_SHEETS
    postings = data.Posting.from_entries(iter(ledger_entries))
    expected = dict(ExpectedPostings.group_by_account(postings))
    for account in iter_accounts(ledger_entries):
        # Only the lookup belongs in the try: a KeyError raised inside
        # check_report must fail the test, not fall through to
        # check_in_report.
        try:
            related = expected[account]
        except KeyError:
            ExpectedPostings.check_in_report(ods, account)
        else:
            related.check_report(ods, START_DATE, STOP_DATE)

@pytest.mark.parametrize('acct_arg', [
    'Liabilities',
    'Accounts payable',
])
def test_main_account_limit(ledger_entries, acct_arg):
    """-a limits the report to the matching Liabilities accounts."""
    retcode, output, errors = run_main([
        '-a', acct_arg,
        '-b', START_DATE.isoformat(),
        '-e', STOP_DATE.isoformat(),
    ])
    assert not errors.getvalue()
    assert retcode == 0
    ods = odf.opendocument.load(output)
    assert get_sheet_names(ods) == ['Liabilities']
    postings = data.Posting.from_entries(ledger_entries)
    for account, expected in ExpectedPostings.group_by_account(postings):
        # Liabilities:UnearnedIncome is expected only when selecting by the
        # 'Liabilities' name, not by 'Accounts payable'.
        should_find = (acct_arg == 'Liabilities'
                       if account == 'Liabilities:UnearnedIncome'
                       else account.startswith('Liabilities'))
        try:
            expected.check_report(ods, START_DATE, STOP_DATE)
            found = True
        except NotFound:
            found = False
        assert found == should_find

def test_main_account_classification_splits_hierarchy(ledger_entries):
    """Selecting the 'Cash' classification reports only Checking and PayPal."""
    retcode, output, errors = run_main([
        '-a', 'Cash',
        '-b', START_DATE.isoformat(),
        '-e', STOP_DATE.isoformat(),
    ])
    assert not errors.getvalue()
    assert retcode == 0
    ods = odf.opendocument.load(output)
    assert get_sheet_names(ods) == ['Assets']
    cash_accounts = ('Assets:Checking', 'Assets:PayPal')
    postings = data.Posting.from_entries(ledger_entries)
    for account, expected in ExpectedPostings.group_by_account(postings):
        should_find = account in cash_accounts
        try:
            expected.check_report(ods, START_DATE, STOP_DATE)
        except NotFound:
            assert not should_find, f"{account} not found in report"
        else:
            assert should_find, f"{account} in report but should be excluded"

@pytest.mark.parametrize('project,start_date,stop_date', [
    ('eighteen', START_DATE, MID_DATE.replace(day=30)),
    ('nineteen', MID_DATE, STOP_DATE),
])
def test_main_project_report(ledger_entries, project, start_date, stop_date):
    """A fund_ledger report for *project* shows only that project's postings.

    Accounts with postings tagged with *project* must match them.  Every
    other account must be absent from the report.
    """
    postings = data.Posting.from_entries(iter(ledger_entries))
    related = None
    for key, group in ExpectedPostings.group_by_meta(postings, 'project'):
        if key == project:
            related = group
            break
    # Fail clearly (rather than with an unbound name) if the books have no
    # postings for this project.
    assert related is not None, f"no postings with project {project!r}"
    retcode, output, errors = run_main([
        f'--begin={start_date.isoformat()}',
        f'--end={stop_date.isoformat()}',
        '--report-type=fund_ledger',
        project,
    ])
    assert not errors.getvalue()
    assert retcode == 0
    ods = odf.opendocument.load(output)
    assert get_sheet_names(ods) == PROJECT_REPORT_SHEETS
    expected = dict(ExpectedPostings.group_by_account(related))
    for account in iter_accounts(ledger_entries):
        # Only the lookup belongs in the try: a KeyError raised inside
        # check_report must fail the test, not be read as "not in project".
        try:
            account_posts = expected[account]
        except KeyError:
            ExpectedPostings.check_not_in_report(ods, account)
        else:
            account_posts.check_report(ods, start_date, stop_date)

@pytest.mark.parametrize('flag', [
    '--disbursements',
    '--receipts',
])
def test_main_cash_report(ledger_entries, flag):
    """Cash reports cover only Assets:Checking and Assets:PayPal.

    --disbursements filters to debit transactions and --receipts to credit
    transactions.  Every other account with postings must be absent from the
    report.
    """
    if flag == '--receipts':
        txn_filter = ledger.ReportType.CREDIT_TRANSACTIONS
    else:
        txn_filter = ledger.ReportType.DEBIT_TRANSACTIONS
    retcode, output, errors = run_main([
        flag,
        '-b', START_DATE.isoformat(),
        '-e', STOP_DATE.isoformat(),
    ])
    assert not errors.getvalue()
    assert retcode == 0
    ods = odf.opendocument.load(output)
    postings = data.Posting.from_entries(ledger_entries)
    for account, expected in ExpectedPostings.group_by_account(postings):
        if account == 'Assets:Checking' or account == 'Assets:PayPal':
            expected.check_txn_report(ods, txn_filter, START_DATE, STOP_DATE)
        else:
            # check_not_in_report only checks the accounts passed to it;
            # called without one, this branch previously asserted nothing.
            ExpectedPostings.check_not_in_report(ods, account)

@pytest.mark.parametrize('arg', [
    'Assets:NoneSuchBank',
    'Funny money',
])
def test_main_invalid_account(caplog, arg):
    """An unknown -a argument exits with status 2 and logs the bad value."""
    retcode, _output, _errors = run_main(['-a', arg])
    assert retcode == 2
    suffix = f': {arg!r}'
    assert any(record.message.endswith(suffix) for record in caplog.records)

def test_main_no_postings(caplog):
    """A search that matches no postings exits with status 65 and logs a warning."""
    retcode, _output, _errors = run_main(['NonexistentProject'])
    assert retcode == 65
    levels = {record.levelname for record in caplog.records}
    assert 'WARNING' in levels