Merge branch 'batch_cache'

This commit is contained in:
Christopher Neugebauer 2016-05-01 19:12:53 +10:00
commit ded5114073
9 changed files with 430 additions and 179 deletions

View file

@ -0,0 +1,119 @@
import contextlib
import functools
from django.contrib.auth.models import User
class BatchController(object):
''' Batches are sets of operations where certain queries for users may be
repeated, but are also unlikely change within the boundaries of the batch.
Batches are keyed per-user. You can mark the edge of the batch with the
``batch`` context manager. If you nest calls to ``batch``, only the
outermost call will have the effect of ending the batch.
Batches store results for functions wrapped with ``memoise``. These results
for the user are flushed at the end of the batch.
If a return for a memoised function has a callable attribute called
``end_batch``, that attribute will be called at the end of the batch.
'''
_user_caches = {}
_NESTING_KEY = "nesting_count"
@classmethod
@contextlib.contextmanager
def batch(cls, user):
''' Marks the entry point for a batch for the given user. '''
cls._enter_batch_context(user)
try:
yield
finally:
# Make sure we clean up in case of errors.
cls._exit_batch_context(user)
@classmethod
def _enter_batch_context(cls, user):
if user not in cls._user_caches:
cls._user_caches[user] = cls._new_cache()
cache = cls._user_caches[user]
cache[cls._NESTING_KEY] += 1
@classmethod
def _exit_batch_context(cls, user):
cache = cls._user_caches[user]
cache[cls._NESTING_KEY] -= 1
if cache[cls._NESTING_KEY] == 0:
cls._call_end_batch_methods(user)
del cls._user_caches[user]
@classmethod
def _call_end_batch_methods(cls, user):
cache = cls._user_caches[user]
ended = set()
while True:
keys = set(cache.keys())
if ended == keys:
break
keys_to_end = keys - ended
for key in keys_to_end:
item = cache[key]
if hasattr(item, 'end_batch') and callable(item.end_batch):
item.end_batch()
ended = ended | keys_to_end
@classmethod
def memoise(cls, func):
''' Decorator that stores the result of the stored function in the
user's results cache until the batch completes. Keyword arguments are
not yet supported.
Arguments:
func (callable(*a)): The function whose results we want
to store. The positional arguments, ``a``, are used as cache
keys.
Returns:
callable(*a): The memosing version of ``func``.
'''
@functools.wraps(func)
def f(*a):
for arg in a:
if isinstance(arg, User):
user = arg
break
else:
raise ValueError("One position argument must be a User")
func_key = (func, tuple(a))
cache = cls.get_cache(user)
if func_key not in cache:
cache[func_key] = func(*a)
return cache[func_key]
return f
@classmethod
def get_cache(cls, user):
if user not in cls._user_caches:
# Return blank cache here, we'll just discard :)
return cls._new_cache()
return cls._user_caches[user]
@classmethod
def _new_cache(cls):
''' Returns a new cache dictionary. '''
cache = {}
cache[cls._NESTING_KEY] = 0
return cache

View file

@ -16,6 +16,7 @@ from registrasion.models import commerce
from registrasion.models import conditions from registrasion.models import conditions
from registrasion.models import inventory from registrasion.models import inventory
from.batch import BatchController
from .category import CategoryController from .category import CategoryController
from .discount import DiscountController from .discount import DiscountController
from .flag import FlagController from .flag import FlagController
@ -34,10 +35,11 @@ def _modifies_cart(func):
def inner(self, *a, **k): def inner(self, *a, **k):
self._fail_if_cart_is_not_active() self._fail_if_cart_is_not_active()
with transaction.atomic(): with transaction.atomic():
with CartController.operations_batch(self.cart.user) as mark: with BatchController.batch(self.cart.user):
mark.mark = True # Marker that we've modified the cart # Mark the version of self in the batch cache as modified
memoised = self.for_user(self.cart.user)
memoised._modified_by_batch = True
return func(self, *a, **k) return func(self, *a, **k)
return inner return inner
@ -47,6 +49,7 @@ class CartController(object):
self.cart = cart self.cart = cart
@classmethod @classmethod
@BatchController.memoise
def for_user(cls, user): def for_user(cls, user):
''' Returns the user's current cart, or creates a new cart ''' Returns the user's current cart, or creates a new cart
if there isn't one ready yet. ''' if there isn't one ready yet. '''
@ -64,59 +67,6 @@ class CartController(object):
) )
return cls(existing) return cls(existing)
# Marks the carts that are currently in batches
_FOR_USER = {}
_BATCH_COUNT = collections.defaultdict(int)
_MODIFIED_CARTS = set()
class _ModificationMarker(object):
pass
@classmethod
@contextlib.contextmanager
def operations_batch(cls, user):
''' Marks the boundary for a batch of operations on a user's cart.
These markers can be nested. Only on exiting the outermost marker will
a batch be ended.
When a batch is ended, discounts are recalculated, and the cart's
revision is increased.
'''
if user not in cls._FOR_USER:
_ctrl = cls.for_user(user)
cls._FOR_USER[user] = (_ctrl, _ctrl.cart.id)
ctrl, _id = cls._FOR_USER[user]
cls._BATCH_COUNT[_id] += 1
try:
success = False
marker = cls._ModificationMarker()
yield marker
if hasattr(marker, "mark"):
cls._MODIFIED_CARTS.add(_id)
success = True
finally:
cls._BATCH_COUNT[_id] -= 1
# Only end on the outermost batch marker, and only if
# it excited cleanly, and a modification occurred
modified = _id in cls._MODIFIED_CARTS
outermost = cls._BATCH_COUNT[_id] == 0
if modified and outermost and success:
ctrl._end_batch()
cls._MODIFIED_CARTS.remove(_id)
# Clear out the cache on the outermost operation
if outermost:
del cls._FOR_USER[user]
def _fail_if_cart_is_not_active(self): def _fail_if_cart_is_not_active(self):
self.cart.refresh_from_db() self.cart.refresh_from_db()
if self.cart.status != commerce.Cart.STATUS_ACTIVE: if self.cart.status != commerce.Cart.STATUS_ACTIVE:
@ -144,6 +94,13 @@ class CartController(object):
self.cart.time_last_updated = timezone.now() self.cart.time_last_updated = timezone.now()
self.cart.reservation_duration = max(reservations) self.cart.reservation_duration = max(reservations)
def end_batch(self):
''' Calls ``_end_batch`` if a modification has been performed in the
previous batch. '''
if hasattr(self,'_modified_by_batch'):
self._end_batch()
def _end_batch(self): def _end_batch(self):
''' Performs operations that occur occur at the end of a batch of ''' Performs operations that occur occur at the end of a batch of
product changes/voucher applications etc. product changes/voucher applications etc.
@ -217,16 +174,14 @@ class CartController(object):
errors = [] errors = []
# Pre-annotate products # Pre-annotate products
products = [p for (p, q) in product_quantities] remainders = ProductController.user_remainders(self.cart.user)
r = ProductController.attach_user_remainders(self.cart.user, products)
with_remainders = dict((p, p) for p in r)
# Test each product limit here # Test each product limit here
for product, quantity in product_quantities: for product, quantity in product_quantities:
if quantity < 0: if quantity < 0:
errors.append((product, "Value must be zero or greater.")) errors.append((product, "Value must be zero or greater."))
limit = with_remainders[product].remainder limit = remainders[product.id]
if quantity > limit: if quantity > limit:
errors.append(( errors.append((
@ -242,12 +197,11 @@ class CartController(object):
by_cat[product.category].append((product, quantity)) by_cat[product.category].append((product, quantity))
# Pre-annotate categories # Pre-annotate categories
r = CategoryController.attach_user_remainders(self.cart.user, by_cat) remainders = CategoryController.user_remainders(self.cart.user)
with_remainders = dict((cat, cat) for cat in r)
# Test each category limit here # Test each category limit here
for category in by_cat: for category in by_cat:
limit = with_remainders[category].remainder limit = remainders[category.id]
# Get the amount so far in the cart # Get the amount so far in the cart
to_add = sum(i[1] for i in by_cat[category]) to_add = sum(i[1] for i in by_cat[category])

View file

@ -7,6 +7,7 @@ from django.db.models import Sum
from django.db.models import When from django.db.models import When
from django.db.models import Value from django.db.models import Value
from .batch import BatchController
class AllProducts(object): class AllProducts(object):
pass pass
@ -39,17 +40,17 @@ class CategoryController(object):
return set(i.category for i in available) return set(i.category for i in available)
@classmethod @classmethod
def attach_user_remainders(cls, user, categories): @BatchController.memoise
def user_remainders(cls, user):
''' '''
Return: Return:
queryset(inventory.Product): A queryset containing items from Mapping[int->int]: A dictionary that maps the category ID to the
``categories``, with an extra attribute -- remainder = the amount user's remainder for that category.
of items from this category that is remaining.
''' '''
ids = [category.id for category in categories] categories = inventory.Category.objects.all()
categories = inventory.Category.objects.filter(id__in=ids)
cart_filter = ( cart_filter = (
Q(product__productitem__cart__user=user) & Q(product__productitem__cart__user=user) &
@ -73,12 +74,4 @@ class CategoryController(object):
categories = categories.annotate(remainder=remainder) categories = categories.annotate(remainder=remainder)
return categories return dict((cat.id, cat.remainder) for cat in categories)
def user_quantity_remaining(self, user):
''' Returns the quantity of this product that the user add in the
current cart. '''
with_remainders = self.attach_user_remainders(user, [self.category])
return with_remainders[0].remainder

View file

@ -1,6 +1,8 @@
import itertools import itertools
from conditions import ConditionController from .batch import BatchController
from .conditions import ConditionController
from registrasion.models import commerce from registrasion.models import commerce
from registrasion.models import conditions from registrasion.models import conditions
@ -10,7 +12,6 @@ from django.db.models import Sum
from django.db.models import Value from django.db.models import Value
from django.db.models import When from django.db.models import When
class DiscountAndQuantity(object): class DiscountAndQuantity(object):
''' Represents a discount that can be applied to a product or category ''' Represents a discount that can be applied to a product or category
for a given user. for a given user.
@ -50,7 +51,22 @@ class DiscountController(object):
categories and products. The discounts also list the available quantity categories and products. The discounts also list the available quantity
for this user, not including products that are pending purchase. ''' for this user, not including products that are pending purchase. '''
filtered_clauses = cls._filtered_discounts(user, categories, products) filtered_clauses = cls._filtered_clauses(user)
# clauses that match provided categories
categories = set(categories)
# clauses that match provided products
products = set(products)
# clauses that match categories for provided products
product_categories = set(product.category for product in products)
# (Not relevant: clauses that match products in provided categories)
all_categories = categories | product_categories
filtered_clauses = (
clause for clause in filtered_clauses
if hasattr(clause, 'product') and clause.product in products or
hasattr(clause, 'category') and clause.category in all_categories
)
discounts = [] discounts = []
@ -84,12 +100,13 @@ class DiscountController(object):
return discounts return discounts
@classmethod @classmethod
def _filtered_discounts(cls, user, categories, products): @BatchController.memoise
def _filtered_clauses(cls, user):
''' '''
Returns: Returns:
Sequence[discountbase]: All discounts that passed the filter Sequence[DiscountForProduct | DiscountForCategory]: All clauses
function. that passed the filter function.
''' '''
@ -98,42 +115,22 @@ class DiscountController(object):
i for i in types if issubclass(i, conditions.DiscountBase) i for i in types if issubclass(i, conditions.DiscountBase)
] ]
# discounts that match provided categories product_clauses = conditions.DiscountForProduct.objects.all()
category_discounts = conditions.DiscountForCategory.objects.filter( product_clauses = product_clauses.select_related(
category__in=categories "discount",
)
# discounts that match provided products
product_discounts = conditions.DiscountForProduct.objects.filter(
product__in=products
)
# discounts that match categories for provided products
product_category_discounts = conditions.DiscountForCategory.objects
product_category_discounts = product_category_discounts.filter(
category__in=(product.category for product in products)
)
# (Not relevant: discounts that match products in provided categories)
product_discounts = product_discounts.select_related(
"product", "product",
"product__category", "product__category",
) )
category_clauses = conditions.DiscountForCategory.objects.all()
all_category_discounts = ( category_clauses = category_clauses.select_related(
category_discounts | product_category_discounts
)
all_category_discounts = all_category_discounts.select_related(
"category", "category",
) "discount",
valid_discounts = conditions.DiscountBase.objects.filter(
Q(discountforproduct__in=product_discounts) |
Q(discountforcategory__in=all_category_discounts)
) )
all_subsets = [] all_subsets = []
for discounttype in discounttypes: for discounttype in discounttypes:
discounts = discounttype.objects.filter(id__in=valid_discounts) discounts = discounttype.objects.all()
ctrl = ConditionController.for_type(discounttype) ctrl = ConditionController.for_type(discounttype)
discounts = ctrl.pre_filter(discounts, user) discounts = ctrl.pre_filter(discounts, user)
all_subsets.append(discounts) all_subsets.append(discounts)
@ -145,8 +142,8 @@ class DiscountController(object):
from_filter = dict((i.id, i) for i in filtered_discounts) from_filter = dict((i.id, i) for i in filtered_discounts)
clause_sets = ( clause_sets = (
product_discounts.filter(discount__in=filtered_discounts), product_clauses.filter(discount__in=filtered_discounts),
all_category_discounts.filter(discount__in=filtered_discounts), category_clauses.filter(discount__in=filtered_discounts),
) )
clause_sets = ( clause_sets = (

View file

@ -6,6 +6,7 @@ from collections import namedtuple
from django.db.models import Count from django.db.models import Count
from django.db.models import Q from django.db.models import Q
from .batch import BatchController
from .conditions import ConditionController from .conditions import ConditionController
from registrasion.models import conditions from registrasion.models import conditions
@ -47,8 +48,6 @@ class FlagController(object):
a list is returned containing all of the products that are *not a list is returned containing all of the products that are *not
enabled*. ''' enabled*. '''
print "GREPME: test_flags()"
if products is not None and product_quantities is not None: if products is not None and product_quantities is not None:
raise ValueError("Please specify only products or " raise ValueError("Please specify only products or "
"product_quantities") "product_quantities")
@ -62,7 +61,7 @@ class FlagController(object):
if products: if products:
# Simplify the query. # Simplify the query.
all_conditions = cls._filtered_flags(user, products) all_conditions = cls._filtered_flags(user)
else: else:
all_conditions = [] all_conditions = []
@ -86,6 +85,8 @@ class FlagController(object):
# from the categories covered by this condition # from the categories covered by this condition
ids = [product.id for product in products] ids = [product.id for product in products]
# TODO: This is re-evaluated a lot.
all_products = inventory.Product.objects.filter(id__in=ids) all_products = inventory.Product.objects.filter(id__in=ids)
cond = ( cond = (
Q(flagbase_set=condition) | Q(flagbase_set=condition) |
@ -117,7 +118,7 @@ class FlagController(object):
if not met and product not in messages: if not met and product not in messages:
messages[product] = message messages[product] = message
total_flags = FlagCounter.count() total_flags = FlagCounter.count(user)
valid = {} valid = {}
@ -160,7 +161,8 @@ class FlagController(object):
return error_fields return error_fields
@classmethod @classmethod
def _filtered_flags(cls, user, products): @BatchController.memoise
def _filtered_flags(cls, user):
''' '''
Returns: Returns:
@ -171,26 +173,15 @@ class FlagController(object):
types = list(ConditionController._controllers()) types = list(ConditionController._controllers())
flagtypes = [i for i in types if issubclass(i, conditions.FlagBase)] flagtypes = [i for i in types if issubclass(i, conditions.FlagBase)]
# Get all flags for the products and categories.
prods = (
product.flagbase_set.all()
for product in products
)
cats = (
category.flagbase_set.all()
for category in set(product.category for product in products)
)
all_flags = reduce(operator.or_, itertools.chain(prods, cats))
all_subsets = [] all_subsets = []
for flagtype in flagtypes: for flagtype in flagtypes:
flags = flagtype.objects.filter(id__in=all_flags) flags = flagtype.objects.all()
ctrl = ConditionController.for_type(flagtype) ctrl = ConditionController.for_type(flagtype)
flags = ctrl.pre_filter(flags, user) flags = ctrl.pre_filter(flags, user)
all_subsets.append(flags) all_subsets.append(flags)
return itertools.chain(*all_subsets) return list(itertools.chain(*all_subsets))
ConditionAndRemainder = namedtuple( ConditionAndRemainder = namedtuple(
@ -220,11 +211,11 @@ _ConditionsCount = namedtuple(
) )
# TODO: this should be cacheable.
class FlagCounter(_FlagCounter): class FlagCounter(_FlagCounter):
@classmethod @classmethod
def count(cls): @BatchController.memoise
def count(cls, user):
# Get the count of how many conditions should exist per product # Get the count of how many conditions should exist per product
flagbases = conditions.FlagBase.objects flagbases = conditions.FlagBase.objects

View file

@ -9,6 +9,7 @@ from django.db.models import Value
from registrasion.models import commerce from registrasion.models import commerce
from registrasion.models import inventory from registrasion.models import inventory
from .batch import BatchController
from .category import CategoryController from .category import CategoryController
from .flag import FlagController from .flag import FlagController
@ -34,18 +35,14 @@ class ProductController(object):
if products is not None: if products is not None:
all_products = set(itertools.chain(all_products, products)) all_products = set(itertools.chain(all_products, products))
categories = set(product.category for product in all_products) category_remainders = CategoryController.user_remainders(user)
r = CategoryController.attach_user_remainders(user, categories) product_remainders = ProductController.user_remainders(user)
cat_quants = dict((c, c) for c in r)
r = ProductController.attach_user_remainders(user, all_products)
prod_quants = dict((p, p) for p in r)
passed_limits = set( passed_limits = set(
product product
for product in all_products for product in all_products
if cat_quants[product.category].remainder > 0 if category_remainders[product.category.id] > 0
if prod_quants[product].remainder > 0 if product_remainders[product.id] > 0
) )
failed_and_messages = FlagController.test_flags( failed_and_messages = FlagController.test_flags(
@ -59,17 +56,16 @@ class ProductController(object):
return out return out
@classmethod @classmethod
def attach_user_remainders(cls, user, products): @BatchController.memoise
def user_remainders(cls, user):
''' '''
Return: Return:
queryset(inventory.Product): A queryset containing items from Mapping[int->int]: A dictionary that maps the product ID to the
``product``, with an extra attribute -- remainder = the amount of user's remainder for that product.
this item that is remaining.
''' '''
ids = [product.id for product in products] products = inventory.Product.objects.all()
products = inventory.Product.objects.filter(id__in=ids)
cart_filter = ( cart_filter = (
Q(productitem__cart__user=user) & Q(productitem__cart__user=user) &
@ -93,12 +89,4 @@ class ProductController(object):
products = products.annotate(remainder=remainder) products = products.annotate(remainder=remainder)
return products return dict((product.id, product.remainder) for product in products)
def user_quantity_remaining(self, user):
''' Returns the quantity of this product that the user add in the
current cart. '''
with_remainders = self.attach_user_remainders(user, [self.product])
return with_remainders[0].remainder

View file

@ -0,0 +1,144 @@
import datetime
import pytz
from django.core.exceptions import ValidationError
from controller_helpers import TestingCartController
from test_cart import RegistrationCartTestCase
from registrasion.controllers.batch import BatchController
from registrasion.controllers.discount import DiscountController
from registrasion.controllers.product import ProductController
from registrasion.models import commerce
from registrasion.models import conditions
UTC = pytz.timezone('UTC')
class BatchTestCase(RegistrationCartTestCase):
def test_no_caches_outside_of_batches(self):
cache_1 = BatchController.get_cache(self.USER_1)
cache_2 = BatchController.get_cache(self.USER_2)
# Identity testing is important here
self.assertIsNot(cache_1, cache_2)
def test_cache_clears_at_batch_exit(self):
with BatchController.batch(self.USER_1):
cache_1 = BatchController.get_cache(self.USER_1)
cache_2 = BatchController.get_cache(self.USER_1)
self.assertIsNot(cache_1, cache_2)
def test_caches_identical_within_nestings(self):
with BatchController.batch(self.USER_1):
cache_1 = BatchController.get_cache(self.USER_1)
with BatchController.batch(self.USER_2):
cache_2 = BatchController.get_cache(self.USER_1)
cache_3 = BatchController.get_cache(self.USER_1)
self.assertIs(cache_1, cache_2)
self.assertIs(cache_2, cache_3)
def test_caches_are_independent_for_different_users(self):
with BatchController.batch(self.USER_1):
cache_1 = BatchController.get_cache(self.USER_1)
with BatchController.batch(self.USER_2):
cache_2 = BatchController.get_cache(self.USER_2)
self.assertIsNot(cache_1, cache_2)
def test_cache_clears_are_independent_for_different_users(self):
with BatchController.batch(self.USER_1):
cache_1 = BatchController.get_cache(self.USER_1)
with BatchController.batch(self.USER_2):
cache_2 = BatchController.get_cache(self.USER_2)
with BatchController.batch(self.USER_2):
cache_3 = BatchController.get_cache(self.USER_2)
cache_4 = BatchController.get_cache(self.USER_1)
self.assertIs(cache_1, cache_4)
self.assertIsNot(cache_1, cache_2)
self.assertIsNot(cache_2, cache_3)
def test_new_caches_for_new_batches(self):
with BatchController.batch(self.USER_1):
cache_1 = BatchController.get_cache(self.USER_1)
with BatchController.batch(self.USER_1):
cache_2 = BatchController.get_cache(self.USER_1)
with BatchController.batch(self.USER_1):
cache_3 = BatchController.get_cache(self.USER_1)
self.assertIs(cache_2, cache_3)
self.assertIsNot(cache_1, cache_2)
def test_memoisation_happens_in_batch_context(self):
with BatchController.batch(self.USER_1):
output_1 = self._memoiseme(self.USER_1)
with BatchController.batch(self.USER_1):
output_2 = self._memoiseme(self.USER_1)
self.assertIs(output_1, output_2)
def test_memoisaion_does_not_happen_outside_batch_context(self):
output_1 = self._memoiseme(self.USER_1)
output_2 = self._memoiseme(self.USER_1)
self.assertIsNot(output_1, output_2)
def test_memoisation_is_user_independent(self):
with BatchController.batch(self.USER_1):
output_1 = self._memoiseme(self.USER_1)
with BatchController.batch(self.USER_2):
output_2 = self._memoiseme(self.USER_2)
output_3 = self._memoiseme(self.USER_1)
self.assertIsNot(output_1, output_2)
self.assertIs(output_1, output_3)
def test_memoisation_clears_outside_batches(self):
with BatchController.batch(self.USER_1):
output_1 = self._memoiseme(self.USER_1)
with BatchController.batch(self.USER_1):
output_2 = self._memoiseme(self.USER_1)
self.assertIsNot(output_1, output_2)
@classmethod
@BatchController.memoise
def _memoiseme(self, user):
return object()
def test_batch_end_functionality_is_called(self):
class Ender(object):
end_count = 0
def end_batch(self):
self.end_count += 1
@BatchController.memoise
def get_ender(user):
return Ender()
# end_batch should get called once on exiting the batch
with BatchController.batch(self.USER_1):
ender = get_ender(self.USER_1)
self.assertEquals(1, ender.end_count)
# end_batch should get called once on exiting the batch
# no matter how deep the object gets cached
with BatchController.batch(self.USER_1):
with BatchController.batch(self.USER_1):
ender = get_ender(self.USER_1)
self.assertEquals(1, ender.end_count)

View file

@ -12,6 +12,7 @@ from registrasion.models import commerce
from registrasion.models import conditions from registrasion.models import conditions
from registrasion.models import inventory from registrasion.models import inventory
from registrasion.models import people from registrasion.models import people
from registrasion.controllers.batch import BatchController
from registrasion.controllers.product import ProductController from registrasion.controllers.product import ProductController
from controller_helpers import TestingCartController from controller_helpers import TestingCartController
@ -360,3 +361,65 @@ class BasicCartTests(RegistrationCartTestCase):
def test_available_products_respects_product_limits(self): def test_available_products_respects_product_limits(self):
self.__available_products_test(self.PROD_4, 6) self.__available_products_test(self.PROD_4, 6)
def test_cart_controller_for_user_is_memoised(self):
# - that for_user is memoised
with BatchController.batch(self.USER_1):
cart = TestingCartController.for_user(self.USER_1)
cart_2 = TestingCartController.for_user(self.USER_1)
self.assertIs(cart, cart_2)
def test_cart_revision_does_not_increment_if_not_modified(self):
cart = TestingCartController.for_user(self.USER_1)
rev_0 = cart.cart.revision
with BatchController.batch(self.USER_1):
# Memoise the cart
same_cart = TestingCartController.for_user(self.USER_1)
# Do nothing on exit
rev_1 = self.reget(cart.cart).revision
self.assertEqual(rev_0, rev_1)
def test_cart_revision_only_increments_at_end_of_batches(self):
cart = TestingCartController.for_user(self.USER_1)
rev_0 = cart.cart.revision
with BatchController.batch(self.USER_1):
# Memoise the cart
same_cart = TestingCartController.for_user(self.USER_1)
same_cart.add_to_cart(self.PROD_1, 1)
rev_1 = self.reget(same_cart.cart).revision
rev_2 = self.reget(cart.cart).revision
self.assertEqual(rev_0, rev_1)
self.assertNotEqual(rev_0, rev_2)
def test_cart_discounts_only_calculated_at_end_of_batches(self):
def count_discounts(cart):
return cart.cart.discountitem_set.count()
cart = TestingCartController.for_user(self.USER_1)
self.make_discount_ceiling("FLOOZLE")
count_0 = count_discounts(cart)
with BatchController.batch(self.USER_1):
# Memoise the cart
same_cart = TestingCartController.for_user(self.USER_1)
with BatchController.batch(self.USER_1):
# Memoise the cart
same_cart_2 = TestingCartController.for_user(self.USER_1)
same_cart_2.add_to_cart(self.PROD_1, 1)
count_1 = count_discounts(same_cart_2)
count_2 = count_discounts(same_cart)
count_3 = count_discounts(cart)
self.assertEqual(0, count_0)
self.assertEqual(0, count_1)
self.assertEqual(0, count_2)
self.assertEqual(1, count_3)

View file

@ -5,9 +5,10 @@ from registrasion import util
from registrasion.models import commerce from registrasion.models import commerce
from registrasion.models import inventory from registrasion.models import inventory
from registrasion.models import people from registrasion.models import people
from registrasion.controllers.discount import DiscountController from registrasion.controllers.batch import BatchController
from registrasion.controllers.cart import CartController from registrasion.controllers.cart import CartController
from registrasion.controllers.credit_note import CreditNoteController from registrasion.controllers.credit_note import CreditNoteController
from registrasion.controllers.discount import DiscountController
from registrasion.controllers.invoice import InvoiceController from registrasion.controllers.invoice import InvoiceController
from registrasion.controllers.product import ProductController from registrasion.controllers.product import ProductController
from registrasion.exceptions import CartValidationError from registrasion.exceptions import CartValidationError
@ -170,6 +171,7 @@ def guided_registration(request):
category__in=cats, category__in=cats,
).select_related("category") ).select_related("category")
with BatchController.batch(request.user):
available_products = set(ProductController.available_products( available_products = set(ProductController.available_products(
request.user, request.user,
products=all_products, products=all_products,
@ -181,7 +183,6 @@ def guided_registration(request):
attendee.save() attendee.save()
return next_step return next_step
with CartController.operations_batch(request.user):
for category in cats: for category in cats:
products = [ products = [
i for i in available_products i for i in available_products
@ -345,6 +346,7 @@ def product_category(request, category_id):
category_id = int(category_id) # Routing is [0-9]+ category_id = int(category_id) # Routing is [0-9]+
category = inventory.Category.objects.get(pk=category_id) category = inventory.Category.objects.get(pk=category_id)
with BatchController.batch(request.user):
products = ProductController.available_products( products = ProductController.available_products(
request.user, request.user,
category=category, category=category,