Merge branch 'batch_cache'
This commit is contained in:
commit
ded5114073
9 changed files with 430 additions and 179 deletions
119
registrasion/controllers/batch.py
Normal file
119
registrasion/controllers/batch.py
Normal file
|
@ -0,0 +1,119 @@
|
|||
import contextlib
|
||||
import functools
|
||||
|
||||
from django.contrib.auth.models import User
|
||||
|
||||
|
||||
class BatchController(object):
|
||||
''' Batches are sets of operations where certain queries for users may be
|
||||
repeated, but are also unlikely change within the boundaries of the batch.
|
||||
|
||||
Batches are keyed per-user. You can mark the edge of the batch with the
|
||||
``batch`` context manager. If you nest calls to ``batch``, only the
|
||||
outermost call will have the effect of ending the batch.
|
||||
|
||||
Batches store results for functions wrapped with ``memoise``. These results
|
||||
for the user are flushed at the end of the batch.
|
||||
|
||||
If a return for a memoised function has a callable attribute called
|
||||
``end_batch``, that attribute will be called at the end of the batch.
|
||||
|
||||
'''
|
||||
|
||||
_user_caches = {}
|
||||
_NESTING_KEY = "nesting_count"
|
||||
|
||||
@classmethod
|
||||
@contextlib.contextmanager
|
||||
def batch(cls, user):
|
||||
''' Marks the entry point for a batch for the given user. '''
|
||||
|
||||
cls._enter_batch_context(user)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
# Make sure we clean up in case of errors.
|
||||
cls._exit_batch_context(user)
|
||||
|
||||
@classmethod
|
||||
def _enter_batch_context(cls, user):
|
||||
if user not in cls._user_caches:
|
||||
cls._user_caches[user] = cls._new_cache()
|
||||
|
||||
cache = cls._user_caches[user]
|
||||
cache[cls._NESTING_KEY] += 1
|
||||
|
||||
@classmethod
|
||||
def _exit_batch_context(cls, user):
|
||||
cache = cls._user_caches[user]
|
||||
cache[cls._NESTING_KEY] -= 1
|
||||
|
||||
if cache[cls._NESTING_KEY] == 0:
|
||||
cls._call_end_batch_methods(user)
|
||||
del cls._user_caches[user]
|
||||
|
||||
@classmethod
|
||||
def _call_end_batch_methods(cls, user):
|
||||
cache = cls._user_caches[user]
|
||||
ended = set()
|
||||
while True:
|
||||
keys = set(cache.keys())
|
||||
if ended == keys:
|
||||
break
|
||||
keys_to_end = keys - ended
|
||||
for key in keys_to_end:
|
||||
item = cache[key]
|
||||
if hasattr(item, 'end_batch') and callable(item.end_batch):
|
||||
item.end_batch()
|
||||
ended = ended | keys_to_end
|
||||
|
||||
@classmethod
|
||||
def memoise(cls, func):
|
||||
''' Decorator that stores the result of the stored function in the
|
||||
user's results cache until the batch completes. Keyword arguments are
|
||||
not yet supported.
|
||||
|
||||
Arguments:
|
||||
func (callable(*a)): The function whose results we want
|
||||
to store. The positional arguments, ``a``, are used as cache
|
||||
keys.
|
||||
|
||||
Returns:
|
||||
callable(*a): The memosing version of ``func``.
|
||||
|
||||
'''
|
||||
|
||||
@functools.wraps(func)
|
||||
def f(*a):
|
||||
|
||||
for arg in a:
|
||||
if isinstance(arg, User):
|
||||
user = arg
|
||||
break
|
||||
else:
|
||||
raise ValueError("One position argument must be a User")
|
||||
|
||||
func_key = (func, tuple(a))
|
||||
cache = cls.get_cache(user)
|
||||
|
||||
if func_key not in cache:
|
||||
cache[func_key] = func(*a)
|
||||
|
||||
return cache[func_key]
|
||||
|
||||
return f
|
||||
|
||||
@classmethod
|
||||
def get_cache(cls, user):
|
||||
if user not in cls._user_caches:
|
||||
# Return blank cache here, we'll just discard :)
|
||||
return cls._new_cache()
|
||||
|
||||
return cls._user_caches[user]
|
||||
|
||||
@classmethod
|
||||
def _new_cache(cls):
|
||||
''' Returns a new cache dictionary. '''
|
||||
cache = {}
|
||||
cache[cls._NESTING_KEY] = 0
|
||||
return cache
|
|
@ -16,6 +16,7 @@ from registrasion.models import commerce
|
|||
from registrasion.models import conditions
|
||||
from registrasion.models import inventory
|
||||
|
||||
from.batch import BatchController
|
||||
from .category import CategoryController
|
||||
from .discount import DiscountController
|
||||
from .flag import FlagController
|
||||
|
@ -34,10 +35,11 @@ def _modifies_cart(func):
|
|||
def inner(self, *a, **k):
|
||||
self._fail_if_cart_is_not_active()
|
||||
with transaction.atomic():
|
||||
with CartController.operations_batch(self.cart.user) as mark:
|
||||
mark.mark = True # Marker that we've modified the cart
|
||||
with BatchController.batch(self.cart.user):
|
||||
# Mark the version of self in the batch cache as modified
|
||||
memoised = self.for_user(self.cart.user)
|
||||
memoised._modified_by_batch = True
|
||||
return func(self, *a, **k)
|
||||
|
||||
return inner
|
||||
|
||||
|
||||
|
@ -47,6 +49,7 @@ class CartController(object):
|
|||
self.cart = cart
|
||||
|
||||
@classmethod
|
||||
@BatchController.memoise
|
||||
def for_user(cls, user):
|
||||
''' Returns the user's current cart, or creates a new cart
|
||||
if there isn't one ready yet. '''
|
||||
|
@ -64,59 +67,6 @@ class CartController(object):
|
|||
)
|
||||
return cls(existing)
|
||||
|
||||
# Marks the carts that are currently in batches
|
||||
_FOR_USER = {}
|
||||
_BATCH_COUNT = collections.defaultdict(int)
|
||||
_MODIFIED_CARTS = set()
|
||||
|
||||
class _ModificationMarker(object):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
@contextlib.contextmanager
|
||||
def operations_batch(cls, user):
|
||||
''' Marks the boundary for a batch of operations on a user's cart.
|
||||
|
||||
These markers can be nested. Only on exiting the outermost marker will
|
||||
a batch be ended.
|
||||
|
||||
When a batch is ended, discounts are recalculated, and the cart's
|
||||
revision is increased.
|
||||
'''
|
||||
|
||||
if user not in cls._FOR_USER:
|
||||
_ctrl = cls.for_user(user)
|
||||
cls._FOR_USER[user] = (_ctrl, _ctrl.cart.id)
|
||||
|
||||
ctrl, _id = cls._FOR_USER[user]
|
||||
|
||||
cls._BATCH_COUNT[_id] += 1
|
||||
try:
|
||||
success = False
|
||||
|
||||
marker = cls._ModificationMarker()
|
||||
yield marker
|
||||
|
||||
if hasattr(marker, "mark"):
|
||||
cls._MODIFIED_CARTS.add(_id)
|
||||
|
||||
success = True
|
||||
finally:
|
||||
|
||||
cls._BATCH_COUNT[_id] -= 1
|
||||
|
||||
# Only end on the outermost batch marker, and only if
|
||||
# it excited cleanly, and a modification occurred
|
||||
modified = _id in cls._MODIFIED_CARTS
|
||||
outermost = cls._BATCH_COUNT[_id] == 0
|
||||
if modified and outermost and success:
|
||||
ctrl._end_batch()
|
||||
cls._MODIFIED_CARTS.remove(_id)
|
||||
|
||||
# Clear out the cache on the outermost operation
|
||||
if outermost:
|
||||
del cls._FOR_USER[user]
|
||||
|
||||
def _fail_if_cart_is_not_active(self):
|
||||
self.cart.refresh_from_db()
|
||||
if self.cart.status != commerce.Cart.STATUS_ACTIVE:
|
||||
|
@ -144,6 +94,13 @@ class CartController(object):
|
|||
self.cart.time_last_updated = timezone.now()
|
||||
self.cart.reservation_duration = max(reservations)
|
||||
|
||||
|
||||
def end_batch(self):
|
||||
''' Calls ``_end_batch`` if a modification has been performed in the
|
||||
previous batch. '''
|
||||
if hasattr(self,'_modified_by_batch'):
|
||||
self._end_batch()
|
||||
|
||||
def _end_batch(self):
|
||||
''' Performs operations that occur occur at the end of a batch of
|
||||
product changes/voucher applications etc.
|
||||
|
@ -217,16 +174,14 @@ class CartController(object):
|
|||
errors = []
|
||||
|
||||
# Pre-annotate products
|
||||
products = [p for (p, q) in product_quantities]
|
||||
r = ProductController.attach_user_remainders(self.cart.user, products)
|
||||
with_remainders = dict((p, p) for p in r)
|
||||
remainders = ProductController.user_remainders(self.cart.user)
|
||||
|
||||
# Test each product limit here
|
||||
for product, quantity in product_quantities:
|
||||
if quantity < 0:
|
||||
errors.append((product, "Value must be zero or greater."))
|
||||
|
||||
limit = with_remainders[product].remainder
|
||||
limit = remainders[product.id]
|
||||
|
||||
if quantity > limit:
|
||||
errors.append((
|
||||
|
@ -242,12 +197,11 @@ class CartController(object):
|
|||
by_cat[product.category].append((product, quantity))
|
||||
|
||||
# Pre-annotate categories
|
||||
r = CategoryController.attach_user_remainders(self.cart.user, by_cat)
|
||||
with_remainders = dict((cat, cat) for cat in r)
|
||||
remainders = CategoryController.user_remainders(self.cart.user)
|
||||
|
||||
# Test each category limit here
|
||||
for category in by_cat:
|
||||
limit = with_remainders[category].remainder
|
||||
limit = remainders[category.id]
|
||||
|
||||
# Get the amount so far in the cart
|
||||
to_add = sum(i[1] for i in by_cat[category])
|
||||
|
|
|
@ -7,6 +7,7 @@ from django.db.models import Sum
|
|||
from django.db.models import When
|
||||
from django.db.models import Value
|
||||
|
||||
from .batch import BatchController
|
||||
|
||||
class AllProducts(object):
|
||||
pass
|
||||
|
@ -39,17 +40,17 @@ class CategoryController(object):
|
|||
return set(i.category for i in available)
|
||||
|
||||
@classmethod
|
||||
def attach_user_remainders(cls, user, categories):
|
||||
@BatchController.memoise
|
||||
def user_remainders(cls, user):
|
||||
'''
|
||||
|
||||
Return:
|
||||
queryset(inventory.Product): A queryset containing items from
|
||||
``categories``, with an extra attribute -- remainder = the amount
|
||||
of items from this category that is remaining.
|
||||
Mapping[int->int]: A dictionary that maps the category ID to the
|
||||
user's remainder for that category.
|
||||
|
||||
'''
|
||||
|
||||
ids = [category.id for category in categories]
|
||||
categories = inventory.Category.objects.filter(id__in=ids)
|
||||
categories = inventory.Category.objects.all()
|
||||
|
||||
cart_filter = (
|
||||
Q(product__productitem__cart__user=user) &
|
||||
|
@ -73,12 +74,4 @@ class CategoryController(object):
|
|||
|
||||
categories = categories.annotate(remainder=remainder)
|
||||
|
||||
return categories
|
||||
|
||||
def user_quantity_remaining(self, user):
|
||||
''' Returns the quantity of this product that the user add in the
|
||||
current cart. '''
|
||||
|
||||
with_remainders = self.attach_user_remainders(user, [self.category])
|
||||
|
||||
return with_remainders[0].remainder
|
||||
return dict((cat.id, cat.remainder) for cat in categories)
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
import itertools
|
||||
|
||||
from conditions import ConditionController
|
||||
from .batch import BatchController
|
||||
from .conditions import ConditionController
|
||||
|
||||
from registrasion.models import commerce
|
||||
from registrasion.models import conditions
|
||||
|
||||
|
@ -10,7 +12,6 @@ from django.db.models import Sum
|
|||
from django.db.models import Value
|
||||
from django.db.models import When
|
||||
|
||||
|
||||
class DiscountAndQuantity(object):
|
||||
''' Represents a discount that can be applied to a product or category
|
||||
for a given user.
|
||||
|
@ -50,7 +51,22 @@ class DiscountController(object):
|
|||
categories and products. The discounts also list the available quantity
|
||||
for this user, not including products that are pending purchase. '''
|
||||
|
||||
filtered_clauses = cls._filtered_discounts(user, categories, products)
|
||||
filtered_clauses = cls._filtered_clauses(user)
|
||||
|
||||
# clauses that match provided categories
|
||||
categories = set(categories)
|
||||
# clauses that match provided products
|
||||
products = set(products)
|
||||
# clauses that match categories for provided products
|
||||
product_categories = set(product.category for product in products)
|
||||
# (Not relevant: clauses that match products in provided categories)
|
||||
all_categories = categories | product_categories
|
||||
|
||||
filtered_clauses = (
|
||||
clause for clause in filtered_clauses
|
||||
if hasattr(clause, 'product') and clause.product in products or
|
||||
hasattr(clause, 'category') and clause.category in all_categories
|
||||
)
|
||||
|
||||
discounts = []
|
||||
|
||||
|
@ -84,12 +100,13 @@ class DiscountController(object):
|
|||
return discounts
|
||||
|
||||
@classmethod
|
||||
def _filtered_discounts(cls, user, categories, products):
|
||||
@BatchController.memoise
|
||||
def _filtered_clauses(cls, user):
|
||||
'''
|
||||
|
||||
Returns:
|
||||
Sequence[discountbase]: All discounts that passed the filter
|
||||
function.
|
||||
Sequence[DiscountForProduct | DiscountForCategory]: All clauses
|
||||
that passed the filter function.
|
||||
|
||||
'''
|
||||
|
||||
|
@ -98,42 +115,22 @@ class DiscountController(object):
|
|||
i for i in types if issubclass(i, conditions.DiscountBase)
|
||||
]
|
||||
|
||||
# discounts that match provided categories
|
||||
category_discounts = conditions.DiscountForCategory.objects.filter(
|
||||
category__in=categories
|
||||
)
|
||||
# discounts that match provided products
|
||||
product_discounts = conditions.DiscountForProduct.objects.filter(
|
||||
product__in=products
|
||||
)
|
||||
# discounts that match categories for provided products
|
||||
product_category_discounts = conditions.DiscountForCategory.objects
|
||||
product_category_discounts = product_category_discounts.filter(
|
||||
category__in=(product.category for product in products)
|
||||
)
|
||||
# (Not relevant: discounts that match products in provided categories)
|
||||
|
||||
product_discounts = product_discounts.select_related(
|
||||
product_clauses = conditions.DiscountForProduct.objects.all()
|
||||
product_clauses = product_clauses.select_related(
|
||||
"discount",
|
||||
"product",
|
||||
"product__category",
|
||||
)
|
||||
|
||||
all_category_discounts = (
|
||||
category_discounts | product_category_discounts
|
||||
)
|
||||
all_category_discounts = all_category_discounts.select_related(
|
||||
category_clauses = conditions.DiscountForCategory.objects.all()
|
||||
category_clauses = category_clauses.select_related(
|
||||
"category",
|
||||
)
|
||||
|
||||
valid_discounts = conditions.DiscountBase.objects.filter(
|
||||
Q(discountforproduct__in=product_discounts) |
|
||||
Q(discountforcategory__in=all_category_discounts)
|
||||
"discount",
|
||||
)
|
||||
|
||||
all_subsets = []
|
||||
|
||||
for discounttype in discounttypes:
|
||||
discounts = discounttype.objects.filter(id__in=valid_discounts)
|
||||
discounts = discounttype.objects.all()
|
||||
ctrl = ConditionController.for_type(discounttype)
|
||||
discounts = ctrl.pre_filter(discounts, user)
|
||||
all_subsets.append(discounts)
|
||||
|
@ -145,8 +142,8 @@ class DiscountController(object):
|
|||
from_filter = dict((i.id, i) for i in filtered_discounts)
|
||||
|
||||
clause_sets = (
|
||||
product_discounts.filter(discount__in=filtered_discounts),
|
||||
all_category_discounts.filter(discount__in=filtered_discounts),
|
||||
product_clauses.filter(discount__in=filtered_discounts),
|
||||
category_clauses.filter(discount__in=filtered_discounts),
|
||||
)
|
||||
|
||||
clause_sets = (
|
||||
|
|
|
@ -6,6 +6,7 @@ from collections import namedtuple
|
|||
from django.db.models import Count
|
||||
from django.db.models import Q
|
||||
|
||||
from .batch import BatchController
|
||||
from .conditions import ConditionController
|
||||
|
||||
from registrasion.models import conditions
|
||||
|
@ -47,8 +48,6 @@ class FlagController(object):
|
|||
a list is returned containing all of the products that are *not
|
||||
enabled*. '''
|
||||
|
||||
print "GREPME: test_flags()"
|
||||
|
||||
if products is not None and product_quantities is not None:
|
||||
raise ValueError("Please specify only products or "
|
||||
"product_quantities")
|
||||
|
@ -62,7 +61,7 @@ class FlagController(object):
|
|||
|
||||
if products:
|
||||
# Simplify the query.
|
||||
all_conditions = cls._filtered_flags(user, products)
|
||||
all_conditions = cls._filtered_flags(user)
|
||||
else:
|
||||
all_conditions = []
|
||||
|
||||
|
@ -86,6 +85,8 @@ class FlagController(object):
|
|||
# from the categories covered by this condition
|
||||
|
||||
ids = [product.id for product in products]
|
||||
|
||||
# TODO: This is re-evaluated a lot.
|
||||
all_products = inventory.Product.objects.filter(id__in=ids)
|
||||
cond = (
|
||||
Q(flagbase_set=condition) |
|
||||
|
@ -117,7 +118,7 @@ class FlagController(object):
|
|||
if not met and product not in messages:
|
||||
messages[product] = message
|
||||
|
||||
total_flags = FlagCounter.count()
|
||||
total_flags = FlagCounter.count(user)
|
||||
|
||||
valid = {}
|
||||
|
||||
|
@ -160,7 +161,8 @@ class FlagController(object):
|
|||
return error_fields
|
||||
|
||||
@classmethod
|
||||
def _filtered_flags(cls, user, products):
|
||||
@BatchController.memoise
|
||||
def _filtered_flags(cls, user):
|
||||
'''
|
||||
|
||||
Returns:
|
||||
|
@ -171,26 +173,15 @@ class FlagController(object):
|
|||
types = list(ConditionController._controllers())
|
||||
flagtypes = [i for i in types if issubclass(i, conditions.FlagBase)]
|
||||
|
||||
# Get all flags for the products and categories.
|
||||
prods = (
|
||||
product.flagbase_set.all()
|
||||
for product in products
|
||||
)
|
||||
cats = (
|
||||
category.flagbase_set.all()
|
||||
for category in set(product.category for product in products)
|
||||
)
|
||||
all_flags = reduce(operator.or_, itertools.chain(prods, cats))
|
||||
|
||||
all_subsets = []
|
||||
|
||||
for flagtype in flagtypes:
|
||||
flags = flagtype.objects.filter(id__in=all_flags)
|
||||
flags = flagtype.objects.all()
|
||||
ctrl = ConditionController.for_type(flagtype)
|
||||
flags = ctrl.pre_filter(flags, user)
|
||||
all_subsets.append(flags)
|
||||
|
||||
return itertools.chain(*all_subsets)
|
||||
return list(itertools.chain(*all_subsets))
|
||||
|
||||
|
||||
ConditionAndRemainder = namedtuple(
|
||||
|
@ -220,11 +211,11 @@ _ConditionsCount = namedtuple(
|
|||
)
|
||||
|
||||
|
||||
# TODO: this should be cacheable.
|
||||
class FlagCounter(_FlagCounter):
|
||||
|
||||
@classmethod
|
||||
def count(cls):
|
||||
@BatchController.memoise
|
||||
def count(cls, user):
|
||||
# Get the count of how many conditions should exist per product
|
||||
flagbases = conditions.FlagBase.objects
|
||||
|
||||
|
|
|
@ -9,6 +9,7 @@ from django.db.models import Value
|
|||
from registrasion.models import commerce
|
||||
from registrasion.models import inventory
|
||||
|
||||
from .batch import BatchController
|
||||
from .category import CategoryController
|
||||
from .flag import FlagController
|
||||
|
||||
|
@ -34,18 +35,14 @@ class ProductController(object):
|
|||
if products is not None:
|
||||
all_products = set(itertools.chain(all_products, products))
|
||||
|
||||
categories = set(product.category for product in all_products)
|
||||
r = CategoryController.attach_user_remainders(user, categories)
|
||||
cat_quants = dict((c, c) for c in r)
|
||||
|
||||
r = ProductController.attach_user_remainders(user, all_products)
|
||||
prod_quants = dict((p, p) for p in r)
|
||||
category_remainders = CategoryController.user_remainders(user)
|
||||
product_remainders = ProductController.user_remainders(user)
|
||||
|
||||
passed_limits = set(
|
||||
product
|
||||
for product in all_products
|
||||
if cat_quants[product.category].remainder > 0
|
||||
if prod_quants[product].remainder > 0
|
||||
if category_remainders[product.category.id] > 0
|
||||
if product_remainders[product.id] > 0
|
||||
)
|
||||
|
||||
failed_and_messages = FlagController.test_flags(
|
||||
|
@ -59,17 +56,16 @@ class ProductController(object):
|
|||
return out
|
||||
|
||||
@classmethod
|
||||
def attach_user_remainders(cls, user, products):
|
||||
@BatchController.memoise
|
||||
def user_remainders(cls, user):
|
||||
'''
|
||||
|
||||
Return:
|
||||
queryset(inventory.Product): A queryset containing items from
|
||||
``product``, with an extra attribute -- remainder = the amount of
|
||||
this item that is remaining.
|
||||
Mapping[int->int]: A dictionary that maps the product ID to the
|
||||
user's remainder for that product.
|
||||
'''
|
||||
|
||||
ids = [product.id for product in products]
|
||||
products = inventory.Product.objects.filter(id__in=ids)
|
||||
products = inventory.Product.objects.all()
|
||||
|
||||
cart_filter = (
|
||||
Q(productitem__cart__user=user) &
|
||||
|
@ -93,12 +89,4 @@ class ProductController(object):
|
|||
|
||||
products = products.annotate(remainder=remainder)
|
||||
|
||||
return products
|
||||
|
||||
def user_quantity_remaining(self, user):
|
||||
''' Returns the quantity of this product that the user add in the
|
||||
current cart. '''
|
||||
|
||||
with_remainders = self.attach_user_remainders(user, [self.product])
|
||||
|
||||
return with_remainders[0].remainder
|
||||
return dict((product.id, product.remainder) for product in products)
|
||||
|
|
144
registrasion/tests/test_batch.py
Normal file
144
registrasion/tests/test_batch.py
Normal file
|
@ -0,0 +1,144 @@
|
|||
import datetime
|
||||
import pytz
|
||||
|
||||
from django.core.exceptions import ValidationError
|
||||
|
||||
from controller_helpers import TestingCartController
|
||||
from test_cart import RegistrationCartTestCase
|
||||
|
||||
from registrasion.controllers.batch import BatchController
|
||||
from registrasion.controllers.discount import DiscountController
|
||||
from registrasion.controllers.product import ProductController
|
||||
from registrasion.models import commerce
|
||||
from registrasion.models import conditions
|
||||
|
||||
UTC = pytz.timezone('UTC')
|
||||
|
||||
|
||||
class BatchTestCase(RegistrationCartTestCase):
|
||||
|
||||
def test_no_caches_outside_of_batches(self):
|
||||
cache_1 = BatchController.get_cache(self.USER_1)
|
||||
cache_2 = BatchController.get_cache(self.USER_2)
|
||||
|
||||
# Identity testing is important here
|
||||
self.assertIsNot(cache_1, cache_2)
|
||||
|
||||
def test_cache_clears_at_batch_exit(self):
|
||||
with BatchController.batch(self.USER_1):
|
||||
cache_1 = BatchController.get_cache(self.USER_1)
|
||||
|
||||
cache_2 = BatchController.get_cache(self.USER_1)
|
||||
|
||||
self.assertIsNot(cache_1, cache_2)
|
||||
|
||||
def test_caches_identical_within_nestings(self):
|
||||
with BatchController.batch(self.USER_1):
|
||||
cache_1 = BatchController.get_cache(self.USER_1)
|
||||
|
||||
with BatchController.batch(self.USER_2):
|
||||
cache_2 = BatchController.get_cache(self.USER_1)
|
||||
|
||||
cache_3 = BatchController.get_cache(self.USER_1)
|
||||
|
||||
self.assertIs(cache_1, cache_2)
|
||||
self.assertIs(cache_2, cache_3)
|
||||
|
||||
def test_caches_are_independent_for_different_users(self):
|
||||
with BatchController.batch(self.USER_1):
|
||||
cache_1 = BatchController.get_cache(self.USER_1)
|
||||
|
||||
with BatchController.batch(self.USER_2):
|
||||
cache_2 = BatchController.get_cache(self.USER_2)
|
||||
|
||||
self.assertIsNot(cache_1, cache_2)
|
||||
|
||||
def test_cache_clears_are_independent_for_different_users(self):
|
||||
with BatchController.batch(self.USER_1):
|
||||
cache_1 = BatchController.get_cache(self.USER_1)
|
||||
|
||||
with BatchController.batch(self.USER_2):
|
||||
cache_2 = BatchController.get_cache(self.USER_2)
|
||||
|
||||
with BatchController.batch(self.USER_2):
|
||||
cache_3 = BatchController.get_cache(self.USER_2)
|
||||
|
||||
cache_4 = BatchController.get_cache(self.USER_1)
|
||||
|
||||
self.assertIs(cache_1, cache_4)
|
||||
self.assertIsNot(cache_1, cache_2)
|
||||
self.assertIsNot(cache_2, cache_3)
|
||||
|
||||
def test_new_caches_for_new_batches(self):
|
||||
with BatchController.batch(self.USER_1):
|
||||
cache_1 = BatchController.get_cache(self.USER_1)
|
||||
|
||||
with BatchController.batch(self.USER_1):
|
||||
cache_2 = BatchController.get_cache(self.USER_1)
|
||||
|
||||
with BatchController.batch(self.USER_1):
|
||||
cache_3 = BatchController.get_cache(self.USER_1)
|
||||
|
||||
self.assertIs(cache_2, cache_3)
|
||||
self.assertIsNot(cache_1, cache_2)
|
||||
|
||||
def test_memoisation_happens_in_batch_context(self):
|
||||
with BatchController.batch(self.USER_1):
|
||||
output_1 = self._memoiseme(self.USER_1)
|
||||
|
||||
with BatchController.batch(self.USER_1):
|
||||
output_2 = self._memoiseme(self.USER_1)
|
||||
|
||||
self.assertIs(output_1, output_2)
|
||||
|
||||
def test_memoisaion_does_not_happen_outside_batch_context(self):
|
||||
output_1 = self._memoiseme(self.USER_1)
|
||||
output_2 = self._memoiseme(self.USER_1)
|
||||
|
||||
self.assertIsNot(output_1, output_2)
|
||||
|
||||
def test_memoisation_is_user_independent(self):
|
||||
with BatchController.batch(self.USER_1):
|
||||
output_1 = self._memoiseme(self.USER_1)
|
||||
with BatchController.batch(self.USER_2):
|
||||
output_2 = self._memoiseme(self.USER_2)
|
||||
output_3 = self._memoiseme(self.USER_1)
|
||||
|
||||
self.assertIsNot(output_1, output_2)
|
||||
self.assertIs(output_1, output_3)
|
||||
|
||||
def test_memoisation_clears_outside_batches(self):
|
||||
with BatchController.batch(self.USER_1):
|
||||
output_1 = self._memoiseme(self.USER_1)
|
||||
|
||||
with BatchController.batch(self.USER_1):
|
||||
output_2 = self._memoiseme(self.USER_1)
|
||||
|
||||
self.assertIsNot(output_1, output_2)
|
||||
|
||||
@classmethod
|
||||
@BatchController.memoise
|
||||
def _memoiseme(self, user):
|
||||
return object()
|
||||
|
||||
def test_batch_end_functionality_is_called(self):
|
||||
class Ender(object):
|
||||
end_count = 0
|
||||
def end_batch(self):
|
||||
self.end_count += 1
|
||||
|
||||
@BatchController.memoise
|
||||
def get_ender(user):
|
||||
return Ender()
|
||||
|
||||
# end_batch should get called once on exiting the batch
|
||||
with BatchController.batch(self.USER_1):
|
||||
ender = get_ender(self.USER_1)
|
||||
self.assertEquals(1, ender.end_count)
|
||||
|
||||
# end_batch should get called once on exiting the batch
|
||||
# no matter how deep the object gets cached
|
||||
with BatchController.batch(self.USER_1):
|
||||
with BatchController.batch(self.USER_1):
|
||||
ender = get_ender(self.USER_1)
|
||||
self.assertEquals(1, ender.end_count)
|
|
@ -12,6 +12,7 @@ from registrasion.models import commerce
|
|||
from registrasion.models import conditions
|
||||
from registrasion.models import inventory
|
||||
from registrasion.models import people
|
||||
from registrasion.controllers.batch import BatchController
|
||||
from registrasion.controllers.product import ProductController
|
||||
|
||||
from controller_helpers import TestingCartController
|
||||
|
@ -360,3 +361,65 @@ class BasicCartTests(RegistrationCartTestCase):
|
|||
|
||||
def test_available_products_respects_product_limits(self):
|
||||
self.__available_products_test(self.PROD_4, 6)
|
||||
|
||||
def test_cart_controller_for_user_is_memoised(self):
|
||||
# - that for_user is memoised
|
||||
with BatchController.batch(self.USER_1):
|
||||
cart = TestingCartController.for_user(self.USER_1)
|
||||
cart_2 = TestingCartController.for_user(self.USER_1)
|
||||
self.assertIs(cart, cart_2)
|
||||
|
||||
def test_cart_revision_does_not_increment_if_not_modified(self):
|
||||
cart = TestingCartController.for_user(self.USER_1)
|
||||
rev_0 = cart.cart.revision
|
||||
|
||||
with BatchController.batch(self.USER_1):
|
||||
# Memoise the cart
|
||||
same_cart = TestingCartController.for_user(self.USER_1)
|
||||
# Do nothing on exit
|
||||
|
||||
rev_1 = self.reget(cart.cart).revision
|
||||
self.assertEqual(rev_0, rev_1)
|
||||
|
||||
def test_cart_revision_only_increments_at_end_of_batches(self):
|
||||
cart = TestingCartController.for_user(self.USER_1)
|
||||
rev_0 = cart.cart.revision
|
||||
|
||||
with BatchController.batch(self.USER_1):
|
||||
# Memoise the cart
|
||||
same_cart = TestingCartController.for_user(self.USER_1)
|
||||
same_cart.add_to_cart(self.PROD_1, 1)
|
||||
rev_1 = self.reget(same_cart.cart).revision
|
||||
|
||||
rev_2 = self.reget(cart.cart).revision
|
||||
|
||||
self.assertEqual(rev_0, rev_1)
|
||||
self.assertNotEqual(rev_0, rev_2)
|
||||
|
||||
def test_cart_discounts_only_calculated_at_end_of_batches(self):
|
||||
def count_discounts(cart):
|
||||
return cart.cart.discountitem_set.count()
|
||||
|
||||
cart = TestingCartController.for_user(self.USER_1)
|
||||
self.make_discount_ceiling("FLOOZLE")
|
||||
count_0 = count_discounts(cart)
|
||||
|
||||
with BatchController.batch(self.USER_1):
|
||||
# Memoise the cart
|
||||
same_cart = TestingCartController.for_user(self.USER_1)
|
||||
|
||||
with BatchController.batch(self.USER_1):
|
||||
# Memoise the cart
|
||||
same_cart_2 = TestingCartController.for_user(self.USER_1)
|
||||
|
||||
same_cart_2.add_to_cart(self.PROD_1, 1)
|
||||
count_1 = count_discounts(same_cart_2)
|
||||
|
||||
count_2 = count_discounts(same_cart)
|
||||
|
||||
count_3 = count_discounts(cart)
|
||||
|
||||
self.assertEqual(0, count_0)
|
||||
self.assertEqual(0, count_1)
|
||||
self.assertEqual(0, count_2)
|
||||
self.assertEqual(1, count_3)
|
||||
|
|
|
@ -5,9 +5,10 @@ from registrasion import util
|
|||
from registrasion.models import commerce
|
||||
from registrasion.models import inventory
|
||||
from registrasion.models import people
|
||||
from registrasion.controllers.discount import DiscountController
|
||||
from registrasion.controllers.batch import BatchController
|
||||
from registrasion.controllers.cart import CartController
|
||||
from registrasion.controllers.credit_note import CreditNoteController
|
||||
from registrasion.controllers.discount import DiscountController
|
||||
from registrasion.controllers.invoice import InvoiceController
|
||||
from registrasion.controllers.product import ProductController
|
||||
from registrasion.exceptions import CartValidationError
|
||||
|
@ -170,6 +171,7 @@ def guided_registration(request):
|
|||
category__in=cats,
|
||||
).select_related("category")
|
||||
|
||||
with BatchController.batch(request.user):
|
||||
available_products = set(ProductController.available_products(
|
||||
request.user,
|
||||
products=all_products,
|
||||
|
@ -181,7 +183,6 @@ def guided_registration(request):
|
|||
attendee.save()
|
||||
return next_step
|
||||
|
||||
with CartController.operations_batch(request.user):
|
||||
for category in cats:
|
||||
products = [
|
||||
i for i in available_products
|
||||
|
@ -345,6 +346,7 @@ def product_category(request, category_id):
|
|||
category_id = int(category_id) # Routing is [0-9]+
|
||||
category = inventory.Category.objects.get(pk=category_id)
|
||||
|
||||
with BatchController.batch(request.user):
|
||||
products = ProductController.available_products(
|
||||
request.user,
|
||||
category=category,
|
||||
|
|
Loading…
Reference in a new issue