From 70057fe3831962f4eb98c5310d882acb6e38a8c7 Mon Sep 17 00:00:00 2001 From: Brett Smith Date: Wed, 3 Jun 2020 19:03:02 -0400 Subject: [PATCH] reports: Start BaseODS class. --- conservancy_beancount/reports/core.py | 540 +++++++++++++++++++++++++- setup.cfg | 2 +- setup.py | 2 + tests/test_reports_spreadsheet.py | 462 ++++++++++++++++++++++ tests/testutil.py | 46 +++ 5 files changed, 1049 insertions(+), 3 deletions(-) diff --git a/conservancy_beancount/reports/core.py b/conservancy_beancount/reports/core.py index f0f3723..72ade4d 100644 --- a/conservancy_beancount/reports/core.py +++ b/conservancy_beancount/reports/core.py @@ -16,19 +16,34 @@ import abc import collections +import datetime +import itertools import operator +import re + +import babel.core # type:ignore[import] +import babel.numbers # type:ignore[import] + +import odf.config # type:ignore[import] +import odf.element # type:ignore[import] +import odf.number # type:ignore[import] +import odf.opendocument # type:ignore[import] +import odf.style # type:ignore[import] +import odf.table # type:ignore[import] +import odf.text # type:ignore[import] from decimal import Decimal - -import babel.numbers # type:ignore[import] +from pathlib import Path from beancount.core import amount as bc_amount from .. import data from typing import ( + cast, overload, Any, + BinaryIO, Callable, DefaultDict, Dict, @@ -37,6 +52,7 @@ from typing import ( Iterator, List, Mapping, + MutableMapping, Optional, Sequence, Set, @@ -52,6 +68,8 @@ from ..beancount_types import ( DecimalCompat = data.DecimalCompat BalanceType = TypeVar('BalanceType', bound='Balance') +ElementType = Callable[..., odf.element.Element] +LinkType = Union[str, Tuple[str, Optional[str]]] RelatedType = TypeVar('RelatedType', bound='RelatedPostings') RT = TypeVar('RT', bound=Sequence) ST = TypeVar('ST') @@ -348,3 +366,521 @@ class BaseSpreadsheet(Generic[RT, ST], metaclass=abc.ABCMeta): if should_end: self.end_section(section) self.end_spreadsheet() + + +class BaseODS(BaseSpreadsheet[RT, ST], metaclass=abc.ABCMeta): + """Abstract base class to help write OpenDocument spreadsheets + + This class provides the very core logic to write an arbitrary set of data + rows to an OpenDocument spreadsheet. It provides helper methods for + building sheets, rows, and cells. + + See also the BaseSpreadsheet base class for additional documentation about + methods you must and can define, the definition of RT and ST, etc. + """ + def __init__(self) -> None: + self.locale = babel.core.Locale.default('LC_MONETARY') + self.currency_fmt_key = 'accounting' + self._name_counter = itertools.count(1) + self._currency_style_cache: MutableMapping[str, odf.style.Style] = {} + self.document = odf.opendocument.OpenDocumentSpreadsheet() + self.init_settings() + self.init_styles() + self.sheet = self.use_sheet("Report") + + ### Low-level document tree manipulation + # The *intent* is that you only need to use these if you're adding new + # methods to manipulate document settings or styles. + + def copy_element(self, elem: odf.element.Element) -> odf.element.Element: + qattrs = dict(self.iter_qattributes(elem)) + retval = odf.element.Element(qname=elem.qname, qattributes=qattrs) + try: + orig_name = retval.getAttribute('name') + except ValueError: + orig_name = None + if orig_name is not None: + retval.setAttribute('name', f'{orig_name}{next(self._name_counter)}') + return retval + + def ensure_child(self, + parent: odf.element.Element, + child_type: ElementType, + **kwargs: Any, + ) -> odf.element.Element: + new_child = child_type(**kwargs) + found_child = self.find_child(parent, new_child) + if found_child is None: + parent.addElement(new_child) + return parent.lastChild + else: + return found_child + + def ensure_config_map_entry(self, + root: odf.element.Element, + map_name: str, + entry_name: str, + ) -> odf.element.Element: + """Return a ``ConfigItemMapEntry`` under ``root`` + + This method ensures there's a ``ConfigItemMapNamed`` named ``map_name`` + under ``root``, and a ``ConfigItemMapEntry`` named ``entry_name`` under + that. Return the ``ConfigItemMapEntry`` element. + """ + config_map = self.ensure_child(root, odf.config.ConfigItemMapNamed, name=map_name) + return self.ensure_child(config_map, odf.config.ConfigItemMapEntry, name=entry_name) + + def find_child(self, + parent: odf.element.Element, + child: odf.element.Element, + ) -> Optional[odf.element.Element]: + attrs = {k: v for k, v in self.iter_attributes(child)} + if not attrs: + return None + for elem in parent.childNodes: + if (elem.qname == child.qname + and all(elem.getAttribute(k) == v for k, v in attrs.items())): + return elem + return None + + def iter_attributes(self, elem: odf.element.Element) -> Iterator[Tuple[str, str]]: + for (_, key), value in self.iter_qattributes(elem): + yield key.lower().replace('-', ''), value + + def iter_qattributes(self, elem: odf.element.Element) -> Iterator[Tuple[Tuple[str, str], str]]: + if elem.attributes: + yield from elem.attributes.items() + + def replace_child(self, + parent: odf.element.Element, + child_type: ElementType, + **kwargs: Any, + ) -> odf.element.Element: + new_child = child_type(**kwargs) + found_child = self.find_child(parent, new_child) + parent.insertBefore(new_child, found_child) + if found_child is not None: + parent.removeChild(found_child) + return new_child + + def set_config(self, + root: odf.element.Element, + name: str, + value: Union[bool, int, str], + config_type: Optional[str]=None, + ) -> None: + """Ensure ``root`` has a ``ConfigItem`` with the given name, type, and value""" + value_s = str(value) + if isinstance(value, bool): + value_s = str(value).lower() + default_type = 'boolean' + elif isinstance(value, str): + default_type = 'string' + if config_type is None: + try: + config_type = default_type + except NameError: + raise ValueError( + f"need config_type for {type(value).__name__} value", + ) from None + item = self.replace_child( + root, odf.config.ConfigItem, name=name, type=config_type, + ) + item.addText(value_s) + + ### Styles + + def _build_currency_style( + self, + root: odf.element.Element, + locale: babel.core.Locale, + code: str, + fmt_index: int, + properties: Optional[odf.style.TextProperties]=None, + *, + fmt_key: Optional[str]=None, + volatile: bool=False, + minintegerdigits: int=1, + ) -> odf.element.Element: + if fmt_key is None: + fmt_key = self.currency_fmt_key + pattern = locale.currency_formats[fmt_key] + fmts = pattern.pattern.split(';') + try: + fmt = fmts[fmt_index] + except IndexError: + fmt = fmts[0] + grouping = pattern.grouping[0] + else: + grouping = pattern.grouping[fmt_index] + zero_s = babel.numbers.format_currency(0, code, '##0.0', locale) + try: + decimal_index = zero_s.rindex('.') + 1 + except ValueError: + decimalplaces = 0 + else: + decimalplaces = len(zero_s) - decimal_index + style = self.replace_child( + root, + odf.number.CurrencyStyle, + name=f'{code}{next(self._name_counter)}', + ) + style.setAttribute('volatile', 'true' if volatile else 'false') + if properties is not None: + style.addElement(properties) + for part in re.split(r"(¤+|[#0,.]+|'[^']+')", fmt): + if not part: + pass + elif not part.strip('#0,.'): + style.addElement(odf.number.Number( + decimalplaces=str(decimalplaces), + grouping='true' if grouping else 'false', + minintegerdigits=str(minintegerdigits), + )) + elif part == '¤': + style.addElement(odf.number.CurrencySymbol( + country=locale.territory, + language=locale.language, + text=babel.numbers.get_currency_symbol(code, locale), + )) + elif part == '¤¤': + style.addElement(odf.number.Text(text=code)) + else: + style.addElement(odf.number.Text(text=part.strip("'"))) + return style + + def currency_style( + self, + code: str, + locale: Optional[babel.core.Locale]=None, + negative_properties: Optional[odf.style.TextProperties]=None, + positive_properties: Optional[odf.style.TextProperties]=None, + root: odf.element.Element=None, + ) -> odf.style.Style: + """Create and return a spreadsheet style to format currency data + + Given a currency code and a locale, this method will create all the + styles necessary to format the currency according to the locale's + rules, including rendering of decimal points and negative values. + + You may optionally pass in TextProperties to use for negative and + positive amounts, respectively. If you don't, negative values will + automatically be rendered in red (text color #f00). + + Results are cached. If you repeatedly call this method with the same + arguments, you'll keep getting the same style returned, which will + only be added to the document once. + """ + if locale is None: + locale = self.locale + if negative_properties is None: + negative_properties = odf.style.TextProperties(color='#ff0000') + if root is None: + root = self.document.styles + cache_parts = [str(id(root)), code, str(locale)] + for key, value in self.iter_attributes(negative_properties): + cache_parts.append(f'{key}={value}') + if positive_properties is not None: + cache_parts.append('') + for key, value in self.iter_attributes(positive_properties): + cache_parts.append(f'{key}={value}') + cache_key = '\0'.join(cache_parts) + try: + style = self._currency_style_cache[cache_key] + except KeyError: + pos_style = self._build_currency_style( + root, locale, code, 0, positive_properties, volatile=True, + ) + curr_style = self._build_currency_style( + root, locale, code, 1, negative_properties, + ) + curr_style.addElement(odf.style.Map( + condition='value()>=0', applystylename=pos_style, + )) + style = self.ensure_child( + self.document.styles, + odf.style.Style, + name=f'{curr_style.getAttribute("name")}Cell', + family='table-cell', + datastylename=curr_style, + ) + self._currency_style_cache[cache_key] = style + return style + + def _merge_style_iter_names( + self, + styles: Sequence[Union[str, odf.style.Style, None]], + ) -> Iterator[str]: + for source in styles: + if source is None: + continue + elif not isinstance(source, str): + source = source.getAttribute('name') + if source.startswith('Merge_'): + orig_names = iter(source.split('_')) + next(orig_names) + yield from orig_names + else: + yield source + + def _merge_styles(self, + new_style: odf.style.Style, + sources: Iterable[odf.style.Style], + ) -> None: + for elem in sources: + for key, new_value in self.iter_attributes(elem): + old_value = new_style.getAttribute(key) + if (key == 'name' + or key == 'displayname' + or old_value == new_value): + pass + elif old_value is None: + new_style.setAttribute(key, new_value) + else: + raise ValueError(f"cannot merge styles with conflicting {key}") + for child in elem.childNodes: + new_style.addElement(self.copy_element(child)) + + def merge_styles(self, + *styles: Union[str, odf.style.Style, None], + ) -> Optional[odf.style.Style]: + """Create a new style from multiple existing styles + + Given any number of existing styles, create a new style that combines + all of those styles' attributes and properties, add it to the document + styles, and return it. + + Styles can be specified by name, or by passing in their Style element. + For convenience, you can also pass in None as an argument; None will + simply be skipped. + + Results are cached. If you repeatedly call this method with the same + arguments, you'll keep getting the same style returned, which will + only be added to the document once. + + If you pass in zero real style arguments, returns None. + If you pass in one style argument, returns that style unchanged. + If you pass in a style that doesn't already exist in the document, + or if you pass in styles that can't be merged (because they have + conflicting attributes), raises ValueError. + """ + name_map: Dict[str, odf.style.Style] = {} + for name in self._merge_style_iter_names(styles): + source = odf.style.Style(name=name) + found = self.find_child(self.document.styles, source) + if found is None: + raise ValueError(f"no style named {name!r}") + name_map[name] = found + if not name_map: + retval = None + elif len(name_map) == 1: + _, retval = name_map.popitem() + else: + new_name = f'Merge_{"_".join(sorted(name_map))}' + retval = self.ensure_child( + self.document.styles, odf.style.Style, name=new_name, + ) + if retval.firstChild is None: + self._merge_styles(retval, name_map.values()) + return retval + + ### Sheets + + def lock_first_row(self, sheet: Optional[odf.table.Table]=None) -> None: + """Lock the first row of cells under the given sheet + + This method sets all the appropriate settings to "lock" the first row + of cells in a sheet, so it stays in view even as the viewer scrolls + through rows. If a sheet is not given, works on ``self.sheet``. + """ + if sheet is None: + sheet = self.sheet + config_map = self.ensure_config_map_entry( + self.view, 'Tables', sheet.getAttribute('name'), + ) + self.set_config(config_map, 'PositionBottom', 1, 'int') + self.set_config(config_map, 'VerticalSplitMode', 2, 'short') + self.set_config(config_map, 'VerticalSplitPosition', 1, 'short') + + def use_sheet(self, name: str) -> odf.table.Table: + """Switch the active sheet ``self.sheet`` to the one with the given name + + If there is no sheet with the given name, create it and append it to + the spreadsheet first. + + If the current active sheet is empty when this method is called, it + will be removed from the spreadsheet. + """ + try: + empty_sheet = not self.sheet.hasChildNodes() + except AttributeError: + empty_sheet = False + if empty_sheet: + self.document.spreadsheet.removeChild(self.sheet) + self.sheet = self.ensure_child( + self.document.spreadsheet, odf.table.Table, name=name, + ) + return self.sheet + + ### Initialization hooks + + def init_settings(self) -> None: + """Hook called to initialize settings + + This method is called by __init__ to populate + ``self.document.settings``. This implementation creates the barest + skeleton structure necessary to support other methods, in particular + ``lock_first_row``. + """ + view_settings = self.ensure_child( + self.document.settings, odf.config.ConfigItemSet, name='ooo:view-settings', + ) + views = self.ensure_child( + view_settings, odf.config.ConfigItemMapIndexed, name='Views', + ) + self.view = self.ensure_child(views, odf.config.ConfigItemMapEntry) + self.set_config(self.view, 'ViewId', 'view1') + + def init_styles(self) -> None: + """Hook called to initialize settings + + This method is called by __init__ to populate + ``self.document.styles``. This implementation creates basic building + block cell styles often used in financial reports. + """ + styles = self.document.styles + self.style_bold = self.ensure_child( + styles, odf.style.Style, name='Bold', family='table-cell', + ) + self.ensure_child( + self.style_bold, odf.style.TextProperties, fontweight='bold', + ) + self.style_starttext: odf.style.Style + self.style_centertext: odf.style.Style + self.style_endtext: odf.style.Style + for textalign in ['start', 'center', 'end']: + aligned_style = self.replace_child( + styles, odf.style.Style, name=f'{textalign.title()}Text', + ) + aligned_style.setAttribute('family', 'table-cell') + aligned_style.addElement(odf.style.ParagraphProperties(textalign=textalign)) + setattr(self, f'style_{textalign}text', aligned_style) + date_style = self.replace_child(styles, odf.number.DateStyle, name='ISODate') + date_style.addElement(odf.number.Year(style='long')) + date_style.addElement(odf.number.Text(text='-')) + date_style.addElement(odf.number.Month(style='long')) + date_style.addElement(odf.number.Text(text='-')) + date_style.addElement(odf.number.Day(style='long')) + self.style_date = self.ensure_child( + styles, + odf.style.Style, + name=f'{date_style.getAttribute("name")}Cell', + family='table-cell', + datastylename=date_style, + ) + self.style_dividerline = self.ensure_child( + styles, odf.style.Style, name='DividerLine', family='table-cell', + ) + self.ensure_child( + self.style_dividerline, + odf.style.TableCellProperties, + borderbottom='1pt solid #0000ff', + ) + + ### Rows and cells + + def add_row(self, *cells: odf.table.TableCell, **attrs: Any) -> odf.table.TableRow: + row = odf.table.TableRow(**attrs) + for cell in cells: + row.addElement(cell) + self.sheet.addElement(row) + return row + + def balance_cell(self, balance: Balance, **attrs: Any) -> odf.table.TableCell: + if balance.is_zero(): + return self.float_cell(0, **attrs) + elif len(balance) == 1: + amount = next(iter(balance.values())) + attrs['stylename'] = self.merge_styles( + attrs.get('stylename'), self.currency_style(amount.currency), + ) + return self.currency_cell(amount, **attrs) + else: + lines = [babel.numbers.format_currency( + number, currency, locale=self.locale, format_type=self.currency_fmt_key, + ) for number, currency in balance.values()] + attrs['stylename'] = self.merge_styles( + attrs.get('stylename'), self.style_endtext, + ) + return self.multiline_cell(lines, **attrs) + + def currency_cell(self, amount: data.Amount, **attrs: Any) -> odf.table.TableCell: + number, currency = amount + cell = odf.table.TableCell(valuetype='currency', value=number, **attrs) + cell.addElement(odf.text.P(text=babel.numbers.format_currency( + number, currency, locale=self.locale, format_type=self.currency_fmt_key, + ))) + return cell + + def date_cell(self, date: datetime.date, **attrs: Any) -> odf.table.TableCell: + attrs.setdefault('stylename', self.style_date) + cell = odf.table.TableCell(valuetype='date', datevalue=date, **attrs) + cell.addElement(odf.text.P(text=date.isoformat())) + return cell + + def float_cell(self, value: Union[int, float, Decimal], **attrs: Any) -> odf.table.TableCell: + cell = odf.table.TableCell(valuetype='float', value=value, **attrs) + cell.addElement(odf.text.P(text=str(value))) + return cell + + def multiline_cell(self, lines: Iterable[Any], **attrs: Any) -> odf.table.TableCell: + cell = odf.table.TableCell(valuetype='string', **attrs) + for line in lines: + cell.addElement(odf.text.P(text=str(line))) + return cell + + def multilink_cell(self, links: Iterable[LinkType], **attrs: Any) -> odf.table.TableCell: + cell = odf.table.TableCell(valuetype='string', **attrs) + for link in links: + if isinstance(link, tuple): + href, text = link + else: + href = link + text = None + cell.addElement(odf.text.P()) + cell.lastChild.addElement(odf.text.A( + type='simple', href=href, text=text, + )) + return cell + + def string_cell(self, text: str, **attrs: Any) -> odf.table.TableCell: + cell = odf.table.TableCell(valuetype='string', **attrs) + cell.addElement(odf.text.P(text=text)) + return cell + + def write_row(self, row: RT) -> None: + """Write a single row of input data to the spreadsheet + + This default implementation adds a single row to the spreadsheet, + with one cell per element of the row. The type of each element + determines what kind of cell is created. + + This implementation will help get you started, but you'll probably + want to override it to specify styles. + """ + out_row = odf.table.TableRow() + for cell_source in row: + if isinstance(cell_source, (int, float, Decimal)): + cell = self.float_cell(cell_source) + else: + cell = self.string_cell(cell_source) + out_row.addElement(cell) + self.sheet.addElement(out_row) + + def save_file(self, out_file: BinaryIO) -> None: + self.document.write(out_file) + + def save_path(self, path: Path, mode: str='w') -> None: + with path.open(f'{mode}b') as out_file: + out_file = cast(BinaryIO, out_file) + self.save_file(out_file) diff --git a/setup.cfg b/setup.cfg index 1a72506..882a777 100644 --- a/setup.cfg +++ b/setup.cfg @@ -3,7 +3,7 @@ test=pytest typecheck=pytest --addopts="--mypy conservancy_beancount" [mypy] -disallow_any_unimported = True +disallow_any_unimported = False disallow_untyped_calls = False disallow_untyped_defs = True show_error_codes = True diff --git a/setup.py b/setup.py index 92e658a..e310157 100755 --- a/setup.py +++ b/setup.py @@ -13,6 +13,8 @@ setup( install_requires=[ 'babel>=2.6', # Debian:python3-babel 'beancount>=2.2', # Debian:beancount + # 1.4.1 crashes when trying to save some documents. + 'odfpy>=1.4.0,!=1.4.1', # Debian:python3-odf 'PyYAML>=3.0', # Debian:python3-yaml 'regex', # Debian:python3-regex 'rt>=2.0', diff --git a/tests/test_reports_spreadsheet.py b/tests/test_reports_spreadsheet.py index a6e4be0..f560167 100644 --- a/tests/test_reports_spreadsheet.py +++ b/tests/test_reports_spreadsheet.py @@ -14,12 +14,53 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . +import datetime +import io +import itertools + import pytest +import babel.core +import babel.numbers +import odf.config +import odf.number +import odf.style +import odf.table +import odf.text + +from decimal import Decimal + from . import testutil from conservancy_beancount.reports import core +EN_US = babel.core.Locale('en', 'US') + +XML_NAMES_LIST = [None, 'ce2', 'xml_testname'] +XML_NAMES = itertools.cycle(XML_NAMES_LIST) + +CURRENCY_CELL_DATA = [ + (Decimal('10.101010'), 'BRL'), + (Decimal('-50.50'), 'GBP'), +] + +LINK_CELL_DATA = [ + 'https://example.org', + ('https://example.net', None), + ('https://example.com', 'Example Site'), +] + +NUMERIC_CELL_DATA = [ + 42, + 42.42, + Decimal('42.42'), +] + +STRING_CELL_DATA = [ + 'Example String', + LINK_CELL_DATA[0], +] + class BaseTester(core.BaseSpreadsheet[tuple, str]): def __init__(self): self.start_call = None @@ -47,10 +88,46 @@ class BaseTester(core.BaseSpreadsheet[tuple, str]): self.written_rows.append(key) +class ODSTester(core.BaseODS[tuple, str]): + def section_key(self, row): + return row[0] + + @pytest.fixture def spreadsheet(): return BaseTester() +@pytest.fixture +def ods_writer(): + retval = ODSTester() + retval.locale = EN_US + return retval + +def get_children(parent, child_type, **kwargs): + return [elem for elem in parent.getElementsByType(child_type) + if all(elem.getAttribute(k) == v for k, v in kwargs.items())] + +def get_child(parent, child_type, index=-1, **kwargs): + try: + return get_children(parent, child_type, **kwargs)[index] + except IndexError: + raise ValueError("no matching child found") from None + +def iter_text(parent): + for child in parent.childNodes: + if isinstance(child, odf.element.Text): + yield child.data + else: + yield from iter_text(child) + +def get_text(parent, joiner=''): + return joiner.join(iter_text(parent)) + +def check_currency_style(curr_style): + child_names = {child.tagName for child in curr_style.childNodes} + assert odf.number.Number().tagName in child_names + assert len(child_names) > 1 + def test_spreadsheet(spreadsheet): rows = [(ch, ii) for ii, ch in enumerate('aabbcc', 1)] spreadsheet.write(iter(rows)) @@ -77,3 +154,388 @@ def test_one_section_spreadsheet(spreadsheet): assert spreadsheet.started_sections == list('A') assert spreadsheet.start_call == [] assert spreadsheet.end_call == spreadsheet.ended_sections + +def test_ods_writer(ods_writer): + rows = [(ch, ii) for ii, ch in enumerate('aabbcc', 1)] + ods_writer.write(iter(rows)) + sheets = ods_writer.document.getElementsByType(odf.table.Table) + assert len(sheets) == 1 + for exp_row, act_row in zip(rows, testutil.ODSCell.from_sheet(sheets[0])): + expected1, expected2 = exp_row + actual1, actual2 = act_row + assert actual1.value_type == 'string' + assert actual1.text == expected1 + assert actual2.value_type == 'float' + assert actual2.value == expected2 + assert actual2.text == str(expected2) + +@pytest.mark.parametrize('save_type', ['file', 'path']) +def test_ods_writer_save(tmp_path, save_type): + rows = list(zip('ABC', 'abc')) + ods_writer = ODSTester() + ods_writer.write(iter(rows)) + if save_type == 'file': + ods_output = io.BytesIO() + ods_writer.save_file(ods_output) + ods_output.seek(0) + else: + ods_output = tmp_path / 'SavePathTest.ods' + ods_writer.save_path(ods_output) + for exp_row, act_row in zip(rows, testutil.ODSCell.from_ods_file(ods_output)): + assert len(exp_row) == len(act_row) + for expected, actual in zip(exp_row, act_row): + assert actual.value_type == 'string' + assert actual.value is None + assert actual.text == expected + +def test_ods_writer_use_sheet(ods_writer): + names = ['One', 'Two'] + for name in names: + ods_writer.use_sheet(name) + ods_writer.write([(name,)]) + ods_writer.use_sheet('End') + sheets = ods_writer.document.getElementsByType(odf.table.Table) + assert len(sheets) == len(names) + 1 + for name, sheet in zip(names, sheets): + texts = [cell.text for row in testutil.ODSCell.from_sheet(sheet) + for cell in row] + assert texts == [name] + +def test_ods_writer_use_sheet_returns_to_prior_sheets(ods_writer): + names = ['One', 'Two'] + sheets = [] + for name in names: + sheets.append(ods_writer.use_sheet(name)) + ods_writer.write([(name,)]) + for name, expected in zip(names, sheets): + actual = ods_writer.use_sheet(name) + assert actual is expected + texts = [cell.text for row in testutil.ODSCell.from_sheet(actual) + for cell in row] + assert texts == [name] + +def test_ods_writer_use_sheet_discards_unused_sheets(ods_writer): + ods_writer.use_sheet('Three') + ods_writer.use_sheet('Two') + ods_writer.use_sheet('One') + sheets = ods_writer.document.getElementsByType(odf.table.Table) + assert len(sheets) == 1 + assert sheets[0].getAttribute('name') == 'One' + +@pytest.mark.parametrize('currency_code', [ + 'USD', + 'EUR', + 'BRL', +]) +def test_ods_currency_style(ods_writer, currency_code): + style = ods_writer.currency_style(currency_code) + assert style.getAttribute('family') == 'table-cell' + curr_style = get_child( + ods_writer.document.styles, + odf.number.CurrencyStyle, + name=style.getAttribute('datastylename'), + ) + check_currency_style(curr_style) + mappings = get_children(curr_style, odf.style.Map) + assert mappings + for mapping in mappings: + check_currency_style(get_child( + ods_writer.document.styles, + odf.number.CurrencyStyle, + name=mapping.getAttribute('applystylename'), + )) + +def test_ods_currency_style_caches(ods_writer): + expected = ods_writer.currency_style('USD') + _ = ods_writer.currency_style('EUR') + actual = ods_writer.currency_style('USD') + assert actual is expected + +def test_ods_currency_style_cache_considers_properties(ods_writer): + bold_text = odf.style.TextProperties(fontweight='bold') + plain = ods_writer.currency_style('USD') + bold = ods_writer.currency_style('USD', positive_properties=bold_text) + assert plain is not bold + assert plain.getAttribute('name') != bold.getAttribute('name') + assert plain.getAttribute('datastylename') != bold.getAttribute('datastylename') + +@pytest.mark.parametrize('attr_name,child_type,checked_attr', [ + ('style_bold', odf.style.TextProperties, 'fontweight'), + ('style_centertext', odf.style.ParagraphProperties, 'textalign'), + ('style_dividerline', odf.style.TableCellProperties, 'borderbottom'), + ('style_endtext', odf.style.ParagraphProperties, 'textalign'), + ('style_starttext', odf.style.ParagraphProperties, 'textalign'), +]) +def test_ods_writer_style(ods_writer, attr_name, child_type, checked_attr): + style = getattr(ods_writer, attr_name) + actual = get_child( + ods_writer.document.styles, + odf.style.Style, + name=style.getAttribute('name'), + ) + assert actual is style + child = get_child(actual, child_type) + assert child.getAttribute(checked_attr) + +def test_ods_writer_merge_styles(ods_writer): + style = ods_writer.merge_styles(ods_writer.style_bold, ods_writer.style_dividerline) + actual = get_child( + ods_writer.document.styles, + odf.style.Style, + name=style.getAttribute('name'), + ) + assert actual is style + assert actual.getAttribute('family') == 'table-cell' + text_props = get_child(actual, odf.style.TextProperties) + assert text_props.getAttribute('fontweight') == 'bold' + cell_props = get_child(actual, odf.style.TableCellProperties) + assert cell_props.getAttribute('borderbottom') + +def test_ods_writer_merge_styles_with_children_and_attributes(ods_writer): + jpy_style = ods_writer.currency_style('JPY') + style = ods_writer.merge_styles(ods_writer.style_bold, jpy_style) + actual = get_child( + ods_writer.document.styles, + odf.style.Style, + name=style.getAttribute('name'), + ) + assert actual is style + assert actual.getAttribute('family') == 'table-cell' + assert actual.getAttribute('datastylename') == jpy_style.getAttribute('datastylename') + text_props = get_child(actual, odf.style.TextProperties) + assert text_props.getAttribute('fontweight') == 'bold' + +def test_ods_writer_merge_styles_caches(ods_writer): + sources = [ods_writer.style_bold, ods_writer.style_dividerline] + style1 = ods_writer.merge_styles(*sources) + style2 = ods_writer.merge_styles(*reversed(sources)) + assert style1 is style2 + assert get_child( + ods_writer.document.styles, + odf.style.Style, + name=style1.getAttribute('name'), + ) + +def test_ods_writer_layer_merge_styles(ods_writer): + usd_style = ods_writer.currency_style('USD') + layer1 = ods_writer.merge_styles(ods_writer.style_bold, ods_writer.style_dividerline) + layer2 = ods_writer.merge_styles(layer1, usd_style) + style_name = layer2.getAttribute('name') + assert style_name.count('Merge_') == 1 + actual = get_child( + ods_writer.document.styles, + odf.style.Style, + name=style_name, + ) + assert actual is layer2 + assert actual.getAttribute('family') == 'table-cell' + assert actual.getAttribute('datastylename') == usd_style.getAttribute('datastylename') + text_props = get_child(actual, odf.style.TextProperties) + assert text_props.getAttribute('fontweight') == 'bold' + cell_props = get_child(actual, odf.style.TableCellProperties) + assert cell_props.getAttribute('borderbottom') + +def test_ods_writer_merge_one_style(ods_writer): + actual = ods_writer.merge_styles(None, ods_writer.style_bold) + assert actual is ods_writer.style_bold + +def test_ods_writer_merge_no_styles(ods_writer): + assert ods_writer.merge_styles() is None + +def test_ods_writer_merge_nonexistent_style(ods_writer): + name = 'Non Existent Style' + with pytest.raises(ValueError, match=repr(name)): + ods_writer.merge_styles(ods_writer.style_bold, name) + +def test_ods_writer_merge_conflicting_styles(ods_writer): + sources = [ods_writer.currency_style(code) for code in ['USD', 'EUR']] + with pytest.raises(ValueError, match='conflicting datastylename'): + ods_writer.merge_styles(*sources) + +def test_ods_writer_date_style(ods_writer): + data_style_name = ods_writer.style_date.getAttribute('datastylename') + actual = get_child( + ods_writer.document.styles, + odf.style.Style, + family='table-cell', + datastylename=data_style_name, + ) + assert actual is ods_writer.style_date + data_style = get_child( + ods_writer.document.styles, + odf.number.DateStyle, + name=data_style_name, + ) + assert len(data_style.childNodes) == 5 + year, t1, month, t2, day = data_style.childNodes + assert year.qname[1] == 'year' + assert year.getAttribute('style') == 'long' + assert get_text(t1) == '-' + assert month.qname[1] == 'month' + assert month.getAttribute('style') == 'long' + assert get_text(t2) == '-' + assert day.qname[1] == 'day' + assert day.getAttribute('style') == 'long' + +def test_ods_lock_first_row(ods_writer): + ods_writer.lock_first_row() + view_settings = get_child( + ods_writer.document.settings, + odf.config.ConfigItemSet, + name='ooo:view-settings', + ) + views = get_child(view_settings, odf.config.ConfigItemMapIndexed, name='Views') + view1 = get_child(views, odf.config.ConfigItemMapEntry, index=0) + config_map = get_child(view1, odf.config.ConfigItemMapNamed, name='Tables') + sheet_name = ods_writer.sheet.getAttribute('name') + config_entry = get_child(config_map, odf.config.ConfigItemMapEntry, name=sheet_name) + for name, ctype, value in [ + ('PositionBottom', 'int', '1'), + ('VerticalSplitMode', 'short', '2'), + ('VerticalSplitPosition', 'short', '1'), + ]: + child = get_child(config_entry, odf.config.ConfigItem, name=name) + assert child.getAttribute('type') == ctype + assert child.firstChild.data == value + +@pytest.mark.parametrize('style_name', XML_NAMES_LIST) +def test_ods_writer_add_row(ods_writer, style_name): + cell1 = ods_writer.string_cell('one') + cell2 = ods_writer.float_cell(42.0) + row = ods_writer.add_row(cell1, cell2, defaultcellstylename=style_name) + assert ods_writer.sheet.lastChild is row + assert row.getAttribute('defaultcellstylename') == style_name + assert row.firstChild is cell1 + assert row.lastChild is cell2 + +def test_ods_writer_add_row_single_cell(ods_writer): + cell = ods_writer.multilink_cell(LINK_CELL_DATA[:1]) + row = ods_writer.add_row(cell) + assert ods_writer.sheet.lastChild is row + assert row.firstChild is cell + assert row.lastChild is cell + +def test_ods_writer_add_row_empty(ods_writer): + row = ods_writer.add_row(stylename='blank') + assert ods_writer.sheet.lastChild is row + assert row.firstChild is None + assert row.getAttribute('stylename') == 'blank' + +def test_ods_writer_balance_cell_empty(ods_writer): + balance = core.Balance() + cell = ods_writer.balance_cell(balance) + assert cell.value_type != 'string' + assert float(cell.value) == 0 + +def test_ods_writer_balance_cell_single_currency(ods_writer): + number = 250 + currency = 'EUR' + balance = core.Balance([testutil.Amount(number, currency)]) + cell = ods_writer.balance_cell(balance) + assert cell.value_type == 'currency' + assert Decimal(cell.value) == number + assert cell.text == babel.numbers.format_currency( + number, currency, locale=EN_US, format_type='accounting', + ) + +def test_ods_writer_balance_cell_multi_currency(ods_writer): + amounts = [testutil.Amount(num, code) for num, code in [ + (2500, 'RUB'), + (3500, 'BRL'), + ]] + balance = core.Balance(amounts) + cell = ods_writer.balance_cell(balance) + assert cell.text == '\0'.join(babel.numbers.format_currency( + number, currency, locale=EN_US, format_type='accounting', + ) for number, currency in amounts) + +@pytest.mark.parametrize('cell_source,style_name', testutil.combine_values( + CURRENCY_CELL_DATA, + XML_NAMES, +)) +def test_ods_writer_currency_cell(ods_writer, cell_source, style_name): + cell = ods_writer.currency_cell(cell_source, stylename=style_name) + number, currency = cell_source + assert cell.getAttribute('valuetype') == 'currency' + assert cell.getAttribute('value') == str(number) + assert cell.getAttribute('stylename') == style_name + expected = babel.numbers.format_currency( + number, currency, locale=EN_US, format_type='accounting', + ) + assert get_text(cell) == expected + +@pytest.mark.parametrize('date,style_name', testutil.combine_values( + [datetime.date(1980, 2, 5), datetime.date(2030, 10, 30)], + XML_NAMES_LIST, +)) +def test_ods_writer_date_cell(ods_writer, date, style_name): + if style_name is None: + expect_style = ods_writer.style_date.getAttribute('name') + cell = ods_writer.date_cell(date) + else: + expect_style = style_name + cell = ods_writer.date_cell(date, stylename=style_name) + date_s = date.isoformat() + assert cell.getAttribute('valuetype') == 'date' + assert cell.getAttribute('datevalue') == date_s + assert cell.getAttribute('stylename') == expect_style + assert get_text(cell) == date_s + +@pytest.mark.parametrize('cell_source,style_name', testutil.combine_values( + NUMERIC_CELL_DATA, + XML_NAMES, +)) +def test_ods_writer_float_cell(ods_writer, cell_source, style_name): + cell = ods_writer.float_cell(cell_source, stylename=style_name) + assert cell.getAttribute('valuetype') == 'float' + assert cell.getAttribute('stylename') == style_name + expected = str(cell_source) + assert cell.getAttribute('value') == expected + assert get_text(cell) == expected + +def test_ods_writer_multiline_cell(ods_writer): + cell = ods_writer.multiline_cell(iter(STRING_CELL_DATA)) + assert cell.getAttribute('valuetype') == 'string' + children = get_children(cell, odf.text.P) + for expected, child in itertools.zip_longest(STRING_CELL_DATA, children): + assert get_text(child) == expected + +@pytest.mark.parametrize('cell_source,style_name', testutil.combine_values( + LINK_CELL_DATA, + XML_NAMES, +)) +def test_ods_writer_multilink_singleton(ods_writer, cell_source, style_name): + cell = ods_writer.multilink_cell([cell_source], stylename=style_name) + assert cell.getAttribute('valuetype') == 'string' + assert cell.getAttribute('stylename') == style_name + try: + href, text = cell_source + except ValueError: + href = cell_source + text = None + anchor = get_child(cell, odf.text.A, type='simple', href=href) + assert get_text(anchor) == (text or '') + +def test_ods_writer_multilink_cell(ods_writer): + cell = ods_writer.multilink_cell(iter(LINK_CELL_DATA)) + assert cell.getAttribute('valuetype') == 'string' + children = get_children(cell, odf.text.A) + for source, child in itertools.zip_longest(LINK_CELL_DATA, children): + try: + href, text = source + except ValueError: + href = source + text = None + assert child.getAttribute('type') == 'simple' + assert child.getAttribute('href') == href + assert get_text(child) == (text or '') + +@pytest.mark.parametrize('cell_source,style_name', testutil.combine_values( + STRING_CELL_DATA, + XML_NAMES, +)) +def test_ods_writer_string_cell(ods_writer, cell_source, style_name): + cell = ods_writer.string_cell(cell_source, stylename=style_name) + assert cell.getAttribute('valuetype') == 'string' + assert cell.getAttribute('stylename') == style_name + assert get_text(cell) == str(cell_source) diff --git a/tests/testutil.py b/tests/testutil.py index e4ee7fc..18b4195 100644 --- a/tests/testutil.py +++ b/tests/testutil.py @@ -22,8 +22,13 @@ import beancount.core.amount as bc_amount import beancount.core.data as bc_data import beancount.loader as bc_loader +import odf.element +import odf.opendocument +import odf.table + from decimal import Decimal from pathlib import Path +from typing import Any, Optional, NamedTuple from conservancy_beancount import books, rtutil @@ -34,6 +39,31 @@ FY_MID_DATE = datetime.date(2020, 9, 1) PAST_DATE = datetime.date(2000, 1, 1) TESTS_DIR = Path(__file__).parent +def _ods_cell_value_type(cell): + assert cell.tagName == 'table:table-cell' + return cell.getAttribute('valuetype') + +def _ods_cell_value(cell): + value_type = cell.getAttribute('valuetype') + if value_type == 'currency' or value_type == 'float': + return Decimal(cell.getAttribute('value')) + elif value_type == 'date': + return datetime.datetime.strptime( + cell.getAttribute('datevalue'), '%Y-%m-%d', + ).date() + else: + return cell.getAttribute('value') + +def _ods_elem_text(elem): + if isinstance(elem, odf.element.Text): + return elem.data + else: + return '\0'.join(_ods_elem_text(child) for child in elem.childNodes) + +odf.element.Element.value_type = property(_ods_cell_value_type) +odf.element.Element.value = property(_ods_cell_value) +odf.element.Element.text = property(_ods_elem_text) + def check_lines_match(lines, expect_patterns, source='output'): for pattern in expect_patterns: assert any(re.search(pattern, line) for line in lines), \ @@ -156,6 +186,22 @@ OPENING_EQUITY_ACCOUNTS = itertools.cycle([ 'Equity:OpeningBalance', ]) +class ODSCell: + @classmethod + def from_row(cls, row): + return row.getElementsByType(odf.table.TableCell) + + @classmethod + def from_sheet(cls, spreadsheet): + for row in spreadsheet.getElementsByType(odf.table.TableRow): + yield list(cls.from_row(row)) + + @classmethod + def from_ods_file(cls, path): + ods = odf.opendocument.load(path) + return cls.from_sheet(ods.spreadsheet) + + def OpeningBalance(acct=None, **txn_meta): if acct is None: acct = next(OPENING_EQUITY_ACCOUNTS)