Merge branch 'batch_cache'

This commit is contained in:
Christopher Neugebauer 2016-05-01 19:12:53 +10:00
commit ded5114073
9 changed files with 430 additions and 179 deletions

View file

@ -0,0 +1,119 @@
import contextlib
import functools
from django.contrib.auth.models import User
class BatchController(object):
''' Batches are sets of operations where certain queries for users may be
repeated, but are also unlikely change within the boundaries of the batch.
Batches are keyed per-user. You can mark the edge of the batch with the
``batch`` context manager. If you nest calls to ``batch``, only the
outermost call will have the effect of ending the batch.
Batches store results for functions wrapped with ``memoise``. These results
for the user are flushed at the end of the batch.
If a return for a memoised function has a callable attribute called
``end_batch``, that attribute will be called at the end of the batch.
'''
_user_caches = {}
_NESTING_KEY = "nesting_count"
@classmethod
@contextlib.contextmanager
def batch(cls, user):
''' Marks the entry point for a batch for the given user. '''
cls._enter_batch_context(user)
try:
yield
finally:
# Make sure we clean up in case of errors.
cls._exit_batch_context(user)
@classmethod
def _enter_batch_context(cls, user):
if user not in cls._user_caches:
cls._user_caches[user] = cls._new_cache()
cache = cls._user_caches[user]
cache[cls._NESTING_KEY] += 1
@classmethod
def _exit_batch_context(cls, user):
cache = cls._user_caches[user]
cache[cls._NESTING_KEY] -= 1
if cache[cls._NESTING_KEY] == 0:
cls._call_end_batch_methods(user)
del cls._user_caches[user]
@classmethod
def _call_end_batch_methods(cls, user):
cache = cls._user_caches[user]
ended = set()
while True:
keys = set(cache.keys())
if ended == keys:
break
keys_to_end = keys - ended
for key in keys_to_end:
item = cache[key]
if hasattr(item, 'end_batch') and callable(item.end_batch):
item.end_batch()
ended = ended | keys_to_end
@classmethod
def memoise(cls, func):
''' Decorator that stores the result of the stored function in the
user's results cache until the batch completes. Keyword arguments are
not yet supported.
Arguments:
func (callable(*a)): The function whose results we want
to store. The positional arguments, ``a``, are used as cache
keys.
Returns:
callable(*a): The memosing version of ``func``.
'''
@functools.wraps(func)
def f(*a):
for arg in a:
if isinstance(arg, User):
user = arg
break
else:
raise ValueError("One position argument must be a User")
func_key = (func, tuple(a))
cache = cls.get_cache(user)
if func_key not in cache:
cache[func_key] = func(*a)
return cache[func_key]
return f
@classmethod
def get_cache(cls, user):
if user not in cls._user_caches:
# Return blank cache here, we'll just discard :)
return cls._new_cache()
return cls._user_caches[user]
@classmethod
def _new_cache(cls):
''' Returns a new cache dictionary. '''
cache = {}
cache[cls._NESTING_KEY] = 0
return cache

View file

@ -16,6 +16,7 @@ from registrasion.models import commerce
from registrasion.models import conditions
from registrasion.models import inventory
from.batch import BatchController
from .category import CategoryController
from .discount import DiscountController
from .flag import FlagController
@ -34,10 +35,11 @@ def _modifies_cart(func):
def inner(self, *a, **k):
self._fail_if_cart_is_not_active()
with transaction.atomic():
with CartController.operations_batch(self.cart.user) as mark:
mark.mark = True # Marker that we've modified the cart
with BatchController.batch(self.cart.user):
# Mark the version of self in the batch cache as modified
memoised = self.for_user(self.cart.user)
memoised._modified_by_batch = True
return func(self, *a, **k)
return inner
@ -47,6 +49,7 @@ class CartController(object):
self.cart = cart
@classmethod
@BatchController.memoise
def for_user(cls, user):
''' Returns the user's current cart, or creates a new cart
if there isn't one ready yet. '''
@ -64,59 +67,6 @@ class CartController(object):
)
return cls(existing)
# Marks the carts that are currently in batches
_FOR_USER = {}
_BATCH_COUNT = collections.defaultdict(int)
_MODIFIED_CARTS = set()
class _ModificationMarker(object):
pass
@classmethod
@contextlib.contextmanager
def operations_batch(cls, user):
''' Marks the boundary for a batch of operations on a user's cart.
These markers can be nested. Only on exiting the outermost marker will
a batch be ended.
When a batch is ended, discounts are recalculated, and the cart's
revision is increased.
'''
if user not in cls._FOR_USER:
_ctrl = cls.for_user(user)
cls._FOR_USER[user] = (_ctrl, _ctrl.cart.id)
ctrl, _id = cls._FOR_USER[user]
cls._BATCH_COUNT[_id] += 1
try:
success = False
marker = cls._ModificationMarker()
yield marker
if hasattr(marker, "mark"):
cls._MODIFIED_CARTS.add(_id)
success = True
finally:
cls._BATCH_COUNT[_id] -= 1
# Only end on the outermost batch marker, and only if
# it excited cleanly, and a modification occurred
modified = _id in cls._MODIFIED_CARTS
outermost = cls._BATCH_COUNT[_id] == 0
if modified and outermost and success:
ctrl._end_batch()
cls._MODIFIED_CARTS.remove(_id)
# Clear out the cache on the outermost operation
if outermost:
del cls._FOR_USER[user]
def _fail_if_cart_is_not_active(self):
self.cart.refresh_from_db()
if self.cart.status != commerce.Cart.STATUS_ACTIVE:
@ -144,6 +94,13 @@ class CartController(object):
self.cart.time_last_updated = timezone.now()
self.cart.reservation_duration = max(reservations)
def end_batch(self):
''' Calls ``_end_batch`` if a modification has been performed in the
previous batch. '''
if hasattr(self,'_modified_by_batch'):
self._end_batch()
def _end_batch(self):
''' Performs operations that occur occur at the end of a batch of
product changes/voucher applications etc.
@ -217,16 +174,14 @@ class CartController(object):
errors = []
# Pre-annotate products
products = [p for (p, q) in product_quantities]
r = ProductController.attach_user_remainders(self.cart.user, products)
with_remainders = dict((p, p) for p in r)
remainders = ProductController.user_remainders(self.cart.user)
# Test each product limit here
for product, quantity in product_quantities:
if quantity < 0:
errors.append((product, "Value must be zero or greater."))
limit = with_remainders[product].remainder
limit = remainders[product.id]
if quantity > limit:
errors.append((
@ -242,12 +197,11 @@ class CartController(object):
by_cat[product.category].append((product, quantity))
# Pre-annotate categories
r = CategoryController.attach_user_remainders(self.cart.user, by_cat)
with_remainders = dict((cat, cat) for cat in r)
remainders = CategoryController.user_remainders(self.cart.user)
# Test each category limit here
for category in by_cat:
limit = with_remainders[category].remainder
limit = remainders[category.id]
# Get the amount so far in the cart
to_add = sum(i[1] for i in by_cat[category])

View file

@ -7,6 +7,7 @@ from django.db.models import Sum
from django.db.models import When
from django.db.models import Value
from .batch import BatchController
class AllProducts(object):
pass
@ -39,17 +40,17 @@ class CategoryController(object):
return set(i.category for i in available)
@classmethod
def attach_user_remainders(cls, user, categories):
@BatchController.memoise
def user_remainders(cls, user):
'''
Return:
queryset(inventory.Product): A queryset containing items from
``categories``, with an extra attribute -- remainder = the amount
of items from this category that is remaining.
Mapping[int->int]: A dictionary that maps the category ID to the
user's remainder for that category.
'''
ids = [category.id for category in categories]
categories = inventory.Category.objects.filter(id__in=ids)
categories = inventory.Category.objects.all()
cart_filter = (
Q(product__productitem__cart__user=user) &
@ -73,12 +74,4 @@ class CategoryController(object):
categories = categories.annotate(remainder=remainder)
return categories
def user_quantity_remaining(self, user):
''' Returns the quantity of this product that the user add in the
current cart. '''
with_remainders = self.attach_user_remainders(user, [self.category])
return with_remainders[0].remainder
return dict((cat.id, cat.remainder) for cat in categories)

View file

@ -1,6 +1,8 @@
import itertools
from conditions import ConditionController
from .batch import BatchController
from .conditions import ConditionController
from registrasion.models import commerce
from registrasion.models import conditions
@ -10,7 +12,6 @@ from django.db.models import Sum
from django.db.models import Value
from django.db.models import When
class DiscountAndQuantity(object):
''' Represents a discount that can be applied to a product or category
for a given user.
@ -50,7 +51,22 @@ class DiscountController(object):
categories and products. The discounts also list the available quantity
for this user, not including products that are pending purchase. '''
filtered_clauses = cls._filtered_discounts(user, categories, products)
filtered_clauses = cls._filtered_clauses(user)
# clauses that match provided categories
categories = set(categories)
# clauses that match provided products
products = set(products)
# clauses that match categories for provided products
product_categories = set(product.category for product in products)
# (Not relevant: clauses that match products in provided categories)
all_categories = categories | product_categories
filtered_clauses = (
clause for clause in filtered_clauses
if hasattr(clause, 'product') and clause.product in products or
hasattr(clause, 'category') and clause.category in all_categories
)
discounts = []
@ -84,12 +100,13 @@ class DiscountController(object):
return discounts
@classmethod
def _filtered_discounts(cls, user, categories, products):
@BatchController.memoise
def _filtered_clauses(cls, user):
'''
Returns:
Sequence[discountbase]: All discounts that passed the filter
function.
Sequence[DiscountForProduct | DiscountForCategory]: All clauses
that passed the filter function.
'''
@ -98,42 +115,22 @@ class DiscountController(object):
i for i in types if issubclass(i, conditions.DiscountBase)
]
# discounts that match provided categories
category_discounts = conditions.DiscountForCategory.objects.filter(
category__in=categories
)
# discounts that match provided products
product_discounts = conditions.DiscountForProduct.objects.filter(
product__in=products
)
# discounts that match categories for provided products
product_category_discounts = conditions.DiscountForCategory.objects
product_category_discounts = product_category_discounts.filter(
category__in=(product.category for product in products)
)
# (Not relevant: discounts that match products in provided categories)
product_discounts = product_discounts.select_related(
product_clauses = conditions.DiscountForProduct.objects.all()
product_clauses = product_clauses.select_related(
"discount",
"product",
"product__category",
)
all_category_discounts = (
category_discounts | product_category_discounts
)
all_category_discounts = all_category_discounts.select_related(
category_clauses = conditions.DiscountForCategory.objects.all()
category_clauses = category_clauses.select_related(
"category",
)
valid_discounts = conditions.DiscountBase.objects.filter(
Q(discountforproduct__in=product_discounts) |
Q(discountforcategory__in=all_category_discounts)
"discount",
)
all_subsets = []
for discounttype in discounttypes:
discounts = discounttype.objects.filter(id__in=valid_discounts)
discounts = discounttype.objects.all()
ctrl = ConditionController.for_type(discounttype)
discounts = ctrl.pre_filter(discounts, user)
all_subsets.append(discounts)
@ -145,8 +142,8 @@ class DiscountController(object):
from_filter = dict((i.id, i) for i in filtered_discounts)
clause_sets = (
product_discounts.filter(discount__in=filtered_discounts),
all_category_discounts.filter(discount__in=filtered_discounts),
product_clauses.filter(discount__in=filtered_discounts),
category_clauses.filter(discount__in=filtered_discounts),
)
clause_sets = (

View file

@ -6,6 +6,7 @@ from collections import namedtuple
from django.db.models import Count
from django.db.models import Q
from .batch import BatchController
from .conditions import ConditionController
from registrasion.models import conditions
@ -47,8 +48,6 @@ class FlagController(object):
a list is returned containing all of the products that are *not
enabled*. '''
print "GREPME: test_flags()"
if products is not None and product_quantities is not None:
raise ValueError("Please specify only products or "
"product_quantities")
@ -62,7 +61,7 @@ class FlagController(object):
if products:
# Simplify the query.
all_conditions = cls._filtered_flags(user, products)
all_conditions = cls._filtered_flags(user)
else:
all_conditions = []
@ -86,6 +85,8 @@ class FlagController(object):
# from the categories covered by this condition
ids = [product.id for product in products]
# TODO: This is re-evaluated a lot.
all_products = inventory.Product.objects.filter(id__in=ids)
cond = (
Q(flagbase_set=condition) |
@ -117,7 +118,7 @@ class FlagController(object):
if not met and product not in messages:
messages[product] = message
total_flags = FlagCounter.count()
total_flags = FlagCounter.count(user)
valid = {}
@ -160,7 +161,8 @@ class FlagController(object):
return error_fields
@classmethod
def _filtered_flags(cls, user, products):
@BatchController.memoise
def _filtered_flags(cls, user):
'''
Returns:
@ -171,26 +173,15 @@ class FlagController(object):
types = list(ConditionController._controllers())
flagtypes = [i for i in types if issubclass(i, conditions.FlagBase)]
# Get all flags for the products and categories.
prods = (
product.flagbase_set.all()
for product in products
)
cats = (
category.flagbase_set.all()
for category in set(product.category for product in products)
)
all_flags = reduce(operator.or_, itertools.chain(prods, cats))
all_subsets = []
for flagtype in flagtypes:
flags = flagtype.objects.filter(id__in=all_flags)
flags = flagtype.objects.all()
ctrl = ConditionController.for_type(flagtype)
flags = ctrl.pre_filter(flags, user)
all_subsets.append(flags)
return itertools.chain(*all_subsets)
return list(itertools.chain(*all_subsets))
ConditionAndRemainder = namedtuple(
@ -220,11 +211,11 @@ _ConditionsCount = namedtuple(
)
# TODO: this should be cacheable.
class FlagCounter(_FlagCounter):
@classmethod
def count(cls):
@BatchController.memoise
def count(cls, user):
# Get the count of how many conditions should exist per product
flagbases = conditions.FlagBase.objects

View file

@ -9,6 +9,7 @@ from django.db.models import Value
from registrasion.models import commerce
from registrasion.models import inventory
from .batch import BatchController
from .category import CategoryController
from .flag import FlagController
@ -34,18 +35,14 @@ class ProductController(object):
if products is not None:
all_products = set(itertools.chain(all_products, products))
categories = set(product.category for product in all_products)
r = CategoryController.attach_user_remainders(user, categories)
cat_quants = dict((c, c) for c in r)
r = ProductController.attach_user_remainders(user, all_products)
prod_quants = dict((p, p) for p in r)
category_remainders = CategoryController.user_remainders(user)
product_remainders = ProductController.user_remainders(user)
passed_limits = set(
product
for product in all_products
if cat_quants[product.category].remainder > 0
if prod_quants[product].remainder > 0
if category_remainders[product.category.id] > 0
if product_remainders[product.id] > 0
)
failed_and_messages = FlagController.test_flags(
@ -59,17 +56,16 @@ class ProductController(object):
return out
@classmethod
def attach_user_remainders(cls, user, products):
@BatchController.memoise
def user_remainders(cls, user):
'''
Return:
queryset(inventory.Product): A queryset containing items from
``product``, with an extra attribute -- remainder = the amount of
this item that is remaining.
Mapping[int->int]: A dictionary that maps the product ID to the
user's remainder for that product.
'''
ids = [product.id for product in products]
products = inventory.Product.objects.filter(id__in=ids)
products = inventory.Product.objects.all()
cart_filter = (
Q(productitem__cart__user=user) &
@ -93,12 +89,4 @@ class ProductController(object):
products = products.annotate(remainder=remainder)
return products
def user_quantity_remaining(self, user):
''' Returns the quantity of this product that the user add in the
current cart. '''
with_remainders = self.attach_user_remainders(user, [self.product])
return with_remainders[0].remainder
return dict((product.id, product.remainder) for product in products)

View file

@ -0,0 +1,144 @@
import datetime
import pytz
from django.core.exceptions import ValidationError
from controller_helpers import TestingCartController
from test_cart import RegistrationCartTestCase
from registrasion.controllers.batch import BatchController
from registrasion.controllers.discount import DiscountController
from registrasion.controllers.product import ProductController
from registrasion.models import commerce
from registrasion.models import conditions
UTC = pytz.timezone('UTC')
class BatchTestCase(RegistrationCartTestCase):
def test_no_caches_outside_of_batches(self):
cache_1 = BatchController.get_cache(self.USER_1)
cache_2 = BatchController.get_cache(self.USER_2)
# Identity testing is important here
self.assertIsNot(cache_1, cache_2)
def test_cache_clears_at_batch_exit(self):
with BatchController.batch(self.USER_1):
cache_1 = BatchController.get_cache(self.USER_1)
cache_2 = BatchController.get_cache(self.USER_1)
self.assertIsNot(cache_1, cache_2)
def test_caches_identical_within_nestings(self):
with BatchController.batch(self.USER_1):
cache_1 = BatchController.get_cache(self.USER_1)
with BatchController.batch(self.USER_2):
cache_2 = BatchController.get_cache(self.USER_1)
cache_3 = BatchController.get_cache(self.USER_1)
self.assertIs(cache_1, cache_2)
self.assertIs(cache_2, cache_3)
def test_caches_are_independent_for_different_users(self):
with BatchController.batch(self.USER_1):
cache_1 = BatchController.get_cache(self.USER_1)
with BatchController.batch(self.USER_2):
cache_2 = BatchController.get_cache(self.USER_2)
self.assertIsNot(cache_1, cache_2)
def test_cache_clears_are_independent_for_different_users(self):
with BatchController.batch(self.USER_1):
cache_1 = BatchController.get_cache(self.USER_1)
with BatchController.batch(self.USER_2):
cache_2 = BatchController.get_cache(self.USER_2)
with BatchController.batch(self.USER_2):
cache_3 = BatchController.get_cache(self.USER_2)
cache_4 = BatchController.get_cache(self.USER_1)
self.assertIs(cache_1, cache_4)
self.assertIsNot(cache_1, cache_2)
self.assertIsNot(cache_2, cache_3)
def test_new_caches_for_new_batches(self):
with BatchController.batch(self.USER_1):
cache_1 = BatchController.get_cache(self.USER_1)
with BatchController.batch(self.USER_1):
cache_2 = BatchController.get_cache(self.USER_1)
with BatchController.batch(self.USER_1):
cache_3 = BatchController.get_cache(self.USER_1)
self.assertIs(cache_2, cache_3)
self.assertIsNot(cache_1, cache_2)
def test_memoisation_happens_in_batch_context(self):
with BatchController.batch(self.USER_1):
output_1 = self._memoiseme(self.USER_1)
with BatchController.batch(self.USER_1):
output_2 = self._memoiseme(self.USER_1)
self.assertIs(output_1, output_2)
def test_memoisaion_does_not_happen_outside_batch_context(self):
output_1 = self._memoiseme(self.USER_1)
output_2 = self._memoiseme(self.USER_1)
self.assertIsNot(output_1, output_2)
def test_memoisation_is_user_independent(self):
with BatchController.batch(self.USER_1):
output_1 = self._memoiseme(self.USER_1)
with BatchController.batch(self.USER_2):
output_2 = self._memoiseme(self.USER_2)
output_3 = self._memoiseme(self.USER_1)
self.assertIsNot(output_1, output_2)
self.assertIs(output_1, output_3)
def test_memoisation_clears_outside_batches(self):
with BatchController.batch(self.USER_1):
output_1 = self._memoiseme(self.USER_1)
with BatchController.batch(self.USER_1):
output_2 = self._memoiseme(self.USER_1)
self.assertIsNot(output_1, output_2)
@classmethod
@BatchController.memoise
def _memoiseme(self, user):
return object()
def test_batch_end_functionality_is_called(self):
class Ender(object):
end_count = 0
def end_batch(self):
self.end_count += 1
@BatchController.memoise
def get_ender(user):
return Ender()
# end_batch should get called once on exiting the batch
with BatchController.batch(self.USER_1):
ender = get_ender(self.USER_1)
self.assertEquals(1, ender.end_count)
# end_batch should get called once on exiting the batch
# no matter how deep the object gets cached
with BatchController.batch(self.USER_1):
with BatchController.batch(self.USER_1):
ender = get_ender(self.USER_1)
self.assertEquals(1, ender.end_count)

View file

@ -12,6 +12,7 @@ from registrasion.models import commerce
from registrasion.models import conditions
from registrasion.models import inventory
from registrasion.models import people
from registrasion.controllers.batch import BatchController
from registrasion.controllers.product import ProductController
from controller_helpers import TestingCartController
@ -360,3 +361,65 @@ class BasicCartTests(RegistrationCartTestCase):
def test_available_products_respects_product_limits(self):
self.__available_products_test(self.PROD_4, 6)
def test_cart_controller_for_user_is_memoised(self):
# - that for_user is memoised
with BatchController.batch(self.USER_1):
cart = TestingCartController.for_user(self.USER_1)
cart_2 = TestingCartController.for_user(self.USER_1)
self.assertIs(cart, cart_2)
def test_cart_revision_does_not_increment_if_not_modified(self):
cart = TestingCartController.for_user(self.USER_1)
rev_0 = cart.cart.revision
with BatchController.batch(self.USER_1):
# Memoise the cart
same_cart = TestingCartController.for_user(self.USER_1)
# Do nothing on exit
rev_1 = self.reget(cart.cart).revision
self.assertEqual(rev_0, rev_1)
def test_cart_revision_only_increments_at_end_of_batches(self):
cart = TestingCartController.for_user(self.USER_1)
rev_0 = cart.cart.revision
with BatchController.batch(self.USER_1):
# Memoise the cart
same_cart = TestingCartController.for_user(self.USER_1)
same_cart.add_to_cart(self.PROD_1, 1)
rev_1 = self.reget(same_cart.cart).revision
rev_2 = self.reget(cart.cart).revision
self.assertEqual(rev_0, rev_1)
self.assertNotEqual(rev_0, rev_2)
def test_cart_discounts_only_calculated_at_end_of_batches(self):
def count_discounts(cart):
return cart.cart.discountitem_set.count()
cart = TestingCartController.for_user(self.USER_1)
self.make_discount_ceiling("FLOOZLE")
count_0 = count_discounts(cart)
with BatchController.batch(self.USER_1):
# Memoise the cart
same_cart = TestingCartController.for_user(self.USER_1)
with BatchController.batch(self.USER_1):
# Memoise the cart
same_cart_2 = TestingCartController.for_user(self.USER_1)
same_cart_2.add_to_cart(self.PROD_1, 1)
count_1 = count_discounts(same_cart_2)
count_2 = count_discounts(same_cart)
count_3 = count_discounts(cart)
self.assertEqual(0, count_0)
self.assertEqual(0, count_1)
self.assertEqual(0, count_2)
self.assertEqual(1, count_3)

View file

@ -5,9 +5,10 @@ from registrasion import util
from registrasion.models import commerce
from registrasion.models import inventory
from registrasion.models import people
from registrasion.controllers.discount import DiscountController
from registrasion.controllers.batch import BatchController
from registrasion.controllers.cart import CartController
from registrasion.controllers.credit_note import CreditNoteController
from registrasion.controllers.discount import DiscountController
from registrasion.controllers.invoice import InvoiceController
from registrasion.controllers.product import ProductController
from registrasion.exceptions import CartValidationError
@ -170,6 +171,7 @@ def guided_registration(request):
category__in=cats,
).select_related("category")
with BatchController.batch(request.user):
available_products = set(ProductController.available_products(
request.user,
products=all_products,
@ -181,7 +183,6 @@ def guided_registration(request):
attendee.save()
return next_step
with CartController.operations_batch(request.user):
for category in cats:
products = [
i for i in available_products
@ -345,6 +346,7 @@ def product_category(request, category_id):
category_id = int(category_id) # Routing is [0-9]+
category = inventory.Category.objects.get(pk=category_id)
with BatchController.batch(request.user):
products = ProductController.available_products(
request.user,
category=category,