query: Convert query functions that return List to Set.
Beancount's built-in renderers expect this and are better equipped for it.
This commit is contained in:
parent
25321a81b0
commit
ef03893bfe
2 changed files with 27 additions and 34 deletions
|
@ -157,16 +157,14 @@ class PostingContext:
|
||||||
class MetaDocs(bc_query_env.AnyMeta):
|
class MetaDocs(bc_query_env.AnyMeta):
|
||||||
"""Return a list of document links from metadata."""
|
"""Return a list of document links from metadata."""
|
||||||
def __init__(self, operands: List[bc_query_compile.EvalNode]) -> None:
|
def __init__(self, operands: List[bc_query_compile.EvalNode]) -> None:
|
||||||
super(bc_query_env.AnyMeta, self).__init__(operands, list)
|
super(bc_query_env.AnyMeta, self).__init__(operands, set)
|
||||||
# The second argument is our return type.
|
# The second argument is our return type.
|
||||||
# It should match the annotated return type of __call__.
|
# It should match the annotated return type of __call__.
|
||||||
|
|
||||||
def __call__(self, context: PostingContext) -> List[str]:
|
def __call__(self, context: PostingContext) -> Set[str]:
|
||||||
raw_value = super().__call__(context)
|
raw_value = super().__call__(context)
|
||||||
if isinstance(raw_value, str):
|
seq = raw_value.split() if isinstance(raw_value, str) else ''
|
||||||
return raw_value.split()
|
return set(seq)
|
||||||
else:
|
|
||||||
return []
|
|
||||||
|
|
||||||
|
|
||||||
class RTField(NamedTuple):
|
class RTField(NamedTuple):
|
||||||
|
@ -247,7 +245,7 @@ class RTTicket(bc_query_compile.EvalFunction):
|
||||||
self._meta_key(meta_op.value)
|
self._meta_key(meta_op.value)
|
||||||
if not rest:
|
if not rest:
|
||||||
operands.append(bc_query_compile.EvalConstant(sys.maxsize))
|
operands.append(bc_query_compile.EvalConstant(sys.maxsize))
|
||||||
super().__init__(operands, list)
|
super().__init__(operands, set)
|
||||||
|
|
||||||
def _rt_key(self, key: str) -> RTField:
|
def _rt_key(self, key: str) -> RTField:
|
||||||
try:
|
try:
|
||||||
|
@ -261,7 +259,7 @@ class RTTicket(bc_query_compile.EvalFunction):
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"metadata key {key!r} does not contain documentation links")
|
raise ValueError(f"metadata key {key!r} does not contain documentation links")
|
||||||
|
|
||||||
def __call__(self, context: PostingContext) -> list:
|
def __call__(self, context: PostingContext) -> Set[object]:
|
||||||
rt_key: str
|
rt_key: str
|
||||||
meta_key: str
|
meta_key: str
|
||||||
limit: int
|
limit: int
|
||||||
|
@ -283,7 +281,7 @@ class RTTicket(bc_query_compile.EvalFunction):
|
||||||
ticket_ids.add(rt_id[0])
|
ticket_ids.add(rt_id[0])
|
||||||
if len(ticket_ids) >= limit:
|
if len(ticket_ids) >= limit:
|
||||||
break
|
break
|
||||||
retval: List[object] = []
|
retval: Set[object] = set()
|
||||||
for ticket_id in ticket_ids:
|
for ticket_id in ticket_ids:
|
||||||
try:
|
try:
|
||||||
rt_ticket = self._rt_cache[ticket_id]
|
rt_ticket = self._rt_cache[ticket_id]
|
||||||
|
@ -294,9 +292,9 @@ class RTTicket(bc_query_compile.EvalFunction):
|
||||||
if field_value is None:
|
if field_value is None:
|
||||||
pass
|
pass
|
||||||
elif isinstance(field_value, list):
|
elif isinstance(field_value, list):
|
||||||
retval.extend(field_value)
|
retval.update(field_value)
|
||||||
else:
|
else:
|
||||||
retval.append(field_value)
|
retval.add(field_value)
|
||||||
return retval
|
return retval
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -86,10 +86,10 @@ def test_rt_ticket_bad_metadata(ticket_query, meta_name):
|
||||||
ticket_query(const_operands('id', meta_name))
|
ticket_query(const_operands('id', meta_name))
|
||||||
|
|
||||||
@pytest.mark.parametrize('field_name,meta_name,expected', [
|
@pytest.mark.parametrize('field_name,meta_name,expected', [
|
||||||
('id', 'rt-id', 1),
|
('id', 'rt-id', {1}),
|
||||||
('Queue', 'approval', 'general'),
|
('Queue', 'approval', {'general'}),
|
||||||
('Requestors', 'invoice', ['mx1@example.org', 'requestor2@example.org']),
|
('Requestors', 'invoice', {'mx1@example.org', 'requestor2@example.org'}),
|
||||||
('Due', 'tax-reporting', datetime.datetime(2017, 1, 14, 12, 1, 0, tzinfo=UTC)),
|
('Due', 'tax-reporting', {datetime.datetime(2017, 1, 14, 12, 1, 0, tzinfo=UTC)}),
|
||||||
])
|
])
|
||||||
def test_rt_ticket_from_txn(ticket_query, field_name, meta_name, expected):
|
def test_rt_ticket_from_txn(ticket_query, field_name, meta_name, expected):
|
||||||
func = ticket_query(const_operands(field_name, meta_name))
|
func = ticket_query(const_operands(field_name, meta_name))
|
||||||
|
@ -97,15 +97,13 @@ def test_rt_ticket_from_txn(ticket_query, field_name, meta_name, expected):
|
||||||
('Assets:Cash', 80),
|
('Assets:Cash', 80),
|
||||||
])
|
])
|
||||||
context = RowContext(txn, txn.postings[0])
|
context = RowContext(txn, txn.postings[0])
|
||||||
if not isinstance(expected, list):
|
|
||||||
expected = [expected]
|
|
||||||
assert func(context) == expected
|
assert func(context) == expected
|
||||||
|
|
||||||
@pytest.mark.parametrize('field_name,meta_name,expected', [
|
@pytest.mark.parametrize('field_name,meta_name,expected', [
|
||||||
('id', 'rt-id', 2),
|
('id', 'rt-id', {2}),
|
||||||
('Queue', 'approval', 'general'),
|
('Queue', 'approval', {'general'}),
|
||||||
('Requestors', 'invoice', ['mx2@example.org', 'requestor2@example.org']),
|
('Requestors', 'invoice', {'mx2@example.org', 'requestor2@example.org'}),
|
||||||
('Due', 'tax-reporting', datetime.datetime(2017, 1, 14, 12, 2, 0, tzinfo=UTC)),
|
('Due', 'tax-reporting', {datetime.datetime(2017, 1, 14, 12, 2, 0, tzinfo=UTC)}),
|
||||||
])
|
])
|
||||||
def test_rt_ticket_from_post(ticket_query, field_name, meta_name, expected):
|
def test_rt_ticket_from_post(ticket_query, field_name, meta_name, expected):
|
||||||
func = ticket_query(const_operands(field_name, meta_name))
|
func = ticket_query(const_operands(field_name, meta_name))
|
||||||
|
@ -113,19 +111,16 @@ def test_rt_ticket_from_post(ticket_query, field_name, meta_name, expected):
|
||||||
('Assets:Cash', 110, {meta_name: 'rt:2/8'}),
|
('Assets:Cash', 110, {meta_name: 'rt:2/8'}),
|
||||||
])
|
])
|
||||||
context = RowContext(txn, txn.postings[0])
|
context = RowContext(txn, txn.postings[0])
|
||||||
if not isinstance(expected, list):
|
|
||||||
expected = [expected]
|
|
||||||
assert func(context) == expected
|
assert func(context) == expected
|
||||||
|
|
||||||
@pytest.mark.parametrize('field_name,meta_name,expected,on_txn', [
|
@pytest.mark.parametrize('field_name,meta_name,expected,on_txn', [
|
||||||
('id', 'approval', [1, 2], True),
|
('id', 'approval', {1, 2}, True),
|
||||||
('Queue', 'check', ['general', 'general'], False),
|
('Queue', 'check', {'general'}, False),
|
||||||
('Requestors', 'invoice', [
|
('Requestors', 'invoice', {
|
||||||
'mx1@example.org',
|
'mx1@example.org',
|
||||||
'mx2@example.org',
|
'mx2@example.org',
|
||||||
'requestor2@example.org',
|
'requestor2@example.org',
|
||||||
'requestor2@example.org',
|
}, False),
|
||||||
], False),
|
|
||||||
])
|
])
|
||||||
def test_rt_ticket_multi_results(ticket_query, field_name, meta_name, expected, on_txn):
|
def test_rt_ticket_multi_results(ticket_query, field_name, meta_name, expected, on_txn):
|
||||||
func = ticket_query(const_operands(field_name, meta_name))
|
func = ticket_query(const_operands(field_name, meta_name))
|
||||||
|
@ -136,7 +131,7 @@ def test_rt_ticket_multi_results(ticket_query, field_name, meta_name, expected,
|
||||||
meta = txn.meta if on_txn else post.meta
|
meta = txn.meta if on_txn else post.meta
|
||||||
meta[meta_name] = 'rt:1/2 Docs/12.pdf rt:2/8'
|
meta[meta_name] = 'rt:1/2 Docs/12.pdf rt:2/8'
|
||||||
context = RowContext(txn, post)
|
context = RowContext(txn, post)
|
||||||
assert sorted(func(context)) == expected
|
assert func(context) == expected
|
||||||
|
|
||||||
@pytest.mark.parametrize('meta_value,on_txn', testutil.combine_values(
|
@pytest.mark.parametrize('meta_value,on_txn', testutil.combine_values(
|
||||||
['', 'Docs/34.pdf', 'Docs/100.pdf Docs/120.pdf'],
|
['', 'Docs/34.pdf', 'Docs/100.pdf Docs/120.pdf'],
|
||||||
|
@ -151,7 +146,7 @@ def test_rt_ticket_no_results(ticket_query, meta_value, on_txn):
|
||||||
meta = txn.meta if on_txn else post.meta
|
meta = txn.meta if on_txn else post.meta
|
||||||
meta['check'] = meta_value
|
meta['check'] = meta_value
|
||||||
context = RowContext(txn, post)
|
context = RowContext(txn, post)
|
||||||
assert func(context) == []
|
assert func(context) == set()
|
||||||
|
|
||||||
def test_rt_ticket_caches_tickets():
|
def test_rt_ticket_caches_tickets():
|
||||||
rt_client = testutil.RTClient()
|
rt_client = testutil.RTClient()
|
||||||
|
@ -162,9 +157,9 @@ def test_rt_ticket_caches_tickets():
|
||||||
('Assets:Cash', 160, {'rt-id': 'rt:3'}),
|
('Assets:Cash', 160, {'rt-id': 'rt:3'}),
|
||||||
])
|
])
|
||||||
context = RowContext(txn, txn.postings[0])
|
context = RowContext(txn, txn.postings[0])
|
||||||
assert func(context) == [3]
|
assert func(context) == {3}
|
||||||
del rt_client.TICKET_DATA['3']
|
del rt_client.TICKET_DATA['3']
|
||||||
assert func(context) == [3]
|
assert func(context) == {3}
|
||||||
|
|
||||||
def test_rt_ticket_caches_tickets_not_found():
|
def test_rt_ticket_caches_tickets_not_found():
|
||||||
rt_client = testutil.RTClient()
|
rt_client = testutil.RTClient()
|
||||||
|
@ -176,9 +171,9 @@ def test_rt_ticket_caches_tickets_not_found():
|
||||||
('Assets:Cash', 160, {'rt-id': 'rt:3'}),
|
('Assets:Cash', 160, {'rt-id': 'rt:3'}),
|
||||||
])
|
])
|
||||||
context = RowContext(txn, txn.postings[0])
|
context = RowContext(txn, txn.postings[0])
|
||||||
assert func(context) == []
|
assert func(context) == set()
|
||||||
rt_client.TICKET_DATA['3'] = rt3
|
rt_client.TICKET_DATA['3'] = rt3
|
||||||
assert func(context) == []
|
assert func(context) == set()
|
||||||
|
|
||||||
def test_books_loader_empty():
|
def test_books_loader_empty():
|
||||||
result = qmod.BooksLoader(None)()
|
result = qmod.BooksLoader(None)()
|
||||||
|
|
Loading…
Reference in a new issue