diff --git a/registrasion/controllers/batch.py b/registrasion/controllers/batch.py new file mode 100644 index 00000000..579d8970 --- /dev/null +++ b/registrasion/controllers/batch.py @@ -0,0 +1,119 @@ +import contextlib +import functools + +from django.contrib.auth.models import User + + +class BatchController(object): + ''' Batches are sets of operations where certain queries for users may be + repeated, but are also unlikely change within the boundaries of the batch. + + Batches are keyed per-user. You can mark the edge of the batch with the + ``batch`` context manager. If you nest calls to ``batch``, only the + outermost call will have the effect of ending the batch. + + Batches store results for functions wrapped with ``memoise``. These results + for the user are flushed at the end of the batch. + + If a return for a memoised function has a callable attribute called + ``end_batch``, that attribute will be called at the end of the batch. + + ''' + + _user_caches = {} + _NESTING_KEY = "nesting_count" + + @classmethod + @contextlib.contextmanager + def batch(cls, user): + ''' Marks the entry point for a batch for the given user. ''' + + cls._enter_batch_context(user) + try: + yield + finally: + # Make sure we clean up in case of errors. + cls._exit_batch_context(user) + + @classmethod + def _enter_batch_context(cls, user): + if user not in cls._user_caches: + cls._user_caches[user] = cls._new_cache() + + cache = cls._user_caches[user] + cache[cls._NESTING_KEY] += 1 + + @classmethod + def _exit_batch_context(cls, user): + cache = cls._user_caches[user] + cache[cls._NESTING_KEY] -= 1 + + if cache[cls._NESTING_KEY] == 0: + cls._call_end_batch_methods(user) + del cls._user_caches[user] + + @classmethod + def _call_end_batch_methods(cls, user): + cache = cls._user_caches[user] + ended = set() + while True: + keys = set(cache.keys()) + if ended == keys: + break + keys_to_end = keys - ended + for key in keys_to_end: + item = cache[key] + if hasattr(item, 'end_batch') and callable(item.end_batch): + item.end_batch() + ended = ended | keys_to_end + + @classmethod + def memoise(cls, func): + ''' Decorator that stores the result of the stored function in the + user's results cache until the batch completes. Keyword arguments are + not yet supported. + + Arguments: + func (callable(*a)): The function whose results we want + to store. The positional arguments, ``a``, are used as cache + keys. + + Returns: + callable(*a): The memosing version of ``func``. + + ''' + + @functools.wraps(func) + def f(*a): + + for arg in a: + if isinstance(arg, User): + user = arg + break + else: + raise ValueError("One position argument must be a User") + + func_key = (func, tuple(a)) + cache = cls.get_cache(user) + + if func_key not in cache: + cache[func_key] = func(*a) + + return cache[func_key] + + return f + + @classmethod + def get_cache(cls, user): + if user not in cls._user_caches: + # Return blank cache here, we'll just discard :) + return cls._new_cache() + + return cls._user_caches[user] + + @classmethod + def _new_cache(cls): + ''' Returns a new cache dictionary. ''' + cache = {} + cache[cls._NESTING_KEY] = 0 + return cache diff --git a/registrasion/controllers/cart.py b/registrasion/controllers/cart.py index dbf7e8a0..d0a9f057 100644 --- a/registrasion/controllers/cart.py +++ b/registrasion/controllers/cart.py @@ -16,6 +16,7 @@ from registrasion.models import commerce from registrasion.models import conditions from registrasion.models import inventory +from.batch import BatchController from .category import CategoryController from .discount import DiscountController from .flag import FlagController @@ -34,10 +35,11 @@ def _modifies_cart(func): def inner(self, *a, **k): self._fail_if_cart_is_not_active() with transaction.atomic(): - with CartController.operations_batch(self.cart.user) as mark: - mark.mark = True # Marker that we've modified the cart + with BatchController.batch(self.cart.user): + # Mark the version of self in the batch cache as modified + memoised = self.for_user(self.cart.user) + memoised._modified_by_batch = True return func(self, *a, **k) - return inner @@ -47,6 +49,7 @@ class CartController(object): self.cart = cart @classmethod + @BatchController.memoise def for_user(cls, user): ''' Returns the user's current cart, or creates a new cart if there isn't one ready yet. ''' @@ -64,59 +67,6 @@ class CartController(object): ) return cls(existing) - # Marks the carts that are currently in batches - _FOR_USER = {} - _BATCH_COUNT = collections.defaultdict(int) - _MODIFIED_CARTS = set() - - class _ModificationMarker(object): - pass - - @classmethod - @contextlib.contextmanager - def operations_batch(cls, user): - ''' Marks the boundary for a batch of operations on a user's cart. - - These markers can be nested. Only on exiting the outermost marker will - a batch be ended. - - When a batch is ended, discounts are recalculated, and the cart's - revision is increased. - ''' - - if user not in cls._FOR_USER: - _ctrl = cls.for_user(user) - cls._FOR_USER[user] = (_ctrl, _ctrl.cart.id) - - ctrl, _id = cls._FOR_USER[user] - - cls._BATCH_COUNT[_id] += 1 - try: - success = False - - marker = cls._ModificationMarker() - yield marker - - if hasattr(marker, "mark"): - cls._MODIFIED_CARTS.add(_id) - - success = True - finally: - - cls._BATCH_COUNT[_id] -= 1 - - # Only end on the outermost batch marker, and only if - # it excited cleanly, and a modification occurred - modified = _id in cls._MODIFIED_CARTS - outermost = cls._BATCH_COUNT[_id] == 0 - if modified and outermost and success: - ctrl._end_batch() - cls._MODIFIED_CARTS.remove(_id) - - # Clear out the cache on the outermost operation - if outermost: - del cls._FOR_USER[user] - def _fail_if_cart_is_not_active(self): self.cart.refresh_from_db() if self.cart.status != commerce.Cart.STATUS_ACTIVE: @@ -144,6 +94,13 @@ class CartController(object): self.cart.time_last_updated = timezone.now() self.cart.reservation_duration = max(reservations) + + def end_batch(self): + ''' Calls ``_end_batch`` if a modification has been performed in the + previous batch. ''' + if hasattr(self,'_modified_by_batch'): + self._end_batch() + def _end_batch(self): ''' Performs operations that occur occur at the end of a batch of product changes/voucher applications etc. @@ -217,16 +174,14 @@ class CartController(object): errors = [] # Pre-annotate products - products = [p for (p, q) in product_quantities] - r = ProductController.attach_user_remainders(self.cart.user, products) - with_remainders = dict((p, p) for p in r) + remainders = ProductController.user_remainders(self.cart.user) # Test each product limit here for product, quantity in product_quantities: if quantity < 0: errors.append((product, "Value must be zero or greater.")) - limit = with_remainders[product].remainder + limit = remainders[product.id] if quantity > limit: errors.append(( @@ -242,12 +197,11 @@ class CartController(object): by_cat[product.category].append((product, quantity)) # Pre-annotate categories - r = CategoryController.attach_user_remainders(self.cart.user, by_cat) - with_remainders = dict((cat, cat) for cat in r) + remainders = CategoryController.user_remainders(self.cart.user) # Test each category limit here for category in by_cat: - limit = with_remainders[category].remainder + limit = remainders[category.id] # Get the amount so far in the cart to_add = sum(i[1] for i in by_cat[category]) diff --git a/registrasion/controllers/category.py b/registrasion/controllers/category.py index 9db8ca9e..4adf09b6 100644 --- a/registrasion/controllers/category.py +++ b/registrasion/controllers/category.py @@ -7,6 +7,7 @@ from django.db.models import Sum from django.db.models import When from django.db.models import Value +from .batch import BatchController class AllProducts(object): pass @@ -39,17 +40,17 @@ class CategoryController(object): return set(i.category for i in available) @classmethod - def attach_user_remainders(cls, user, categories): + @BatchController.memoise + def user_remainders(cls, user): ''' Return: - queryset(inventory.Product): A queryset containing items from - ``categories``, with an extra attribute -- remainder = the amount - of items from this category that is remaining. + Mapping[int->int]: A dictionary that maps the category ID to the + user's remainder for that category. + ''' - ids = [category.id for category in categories] - categories = inventory.Category.objects.filter(id__in=ids) + categories = inventory.Category.objects.all() cart_filter = ( Q(product__productitem__cart__user=user) & @@ -73,12 +74,4 @@ class CategoryController(object): categories = categories.annotate(remainder=remainder) - return categories - - def user_quantity_remaining(self, user): - ''' Returns the quantity of this product that the user add in the - current cart. ''' - - with_remainders = self.attach_user_remainders(user, [self.category]) - - return with_remainders[0].remainder + return dict((cat.id, cat.remainder) for cat in categories) diff --git a/registrasion/controllers/discount.py b/registrasion/controllers/discount.py index f4f88ed2..984fe214 100644 --- a/registrasion/controllers/discount.py +++ b/registrasion/controllers/discount.py @@ -1,6 +1,8 @@ import itertools -from conditions import ConditionController +from .batch import BatchController +from .conditions import ConditionController + from registrasion.models import commerce from registrasion.models import conditions @@ -10,7 +12,6 @@ from django.db.models import Sum from django.db.models import Value from django.db.models import When - class DiscountAndQuantity(object): ''' Represents a discount that can be applied to a product or category for a given user. @@ -50,7 +51,22 @@ class DiscountController(object): categories and products. The discounts also list the available quantity for this user, not including products that are pending purchase. ''' - filtered_clauses = cls._filtered_discounts(user, categories, products) + filtered_clauses = cls._filtered_clauses(user) + + # clauses that match provided categories + categories = set(categories) + # clauses that match provided products + products = set(products) + # clauses that match categories for provided products + product_categories = set(product.category for product in products) + # (Not relevant: clauses that match products in provided categories) + all_categories = categories | product_categories + + filtered_clauses = ( + clause for clause in filtered_clauses + if hasattr(clause, 'product') and clause.product in products or + hasattr(clause, 'category') and clause.category in all_categories + ) discounts = [] @@ -84,12 +100,13 @@ class DiscountController(object): return discounts @classmethod - def _filtered_discounts(cls, user, categories, products): + @BatchController.memoise + def _filtered_clauses(cls, user): ''' Returns: - Sequence[discountbase]: All discounts that passed the filter - function. + Sequence[DiscountForProduct | DiscountForCategory]: All clauses + that passed the filter function. ''' @@ -98,42 +115,22 @@ class DiscountController(object): i for i in types if issubclass(i, conditions.DiscountBase) ] - # discounts that match provided categories - category_discounts = conditions.DiscountForCategory.objects.filter( - category__in=categories - ) - # discounts that match provided products - product_discounts = conditions.DiscountForProduct.objects.filter( - product__in=products - ) - # discounts that match categories for provided products - product_category_discounts = conditions.DiscountForCategory.objects - product_category_discounts = product_category_discounts.filter( - category__in=(product.category for product in products) - ) - # (Not relevant: discounts that match products in provided categories) - - product_discounts = product_discounts.select_related( + product_clauses = conditions.DiscountForProduct.objects.all() + product_clauses = product_clauses.select_related( + "discount", "product", "product__category", ) - - all_category_discounts = ( - category_discounts | product_category_discounts - ) - all_category_discounts = all_category_discounts.select_related( + category_clauses = conditions.DiscountForCategory.objects.all() + category_clauses = category_clauses.select_related( "category", - ) - - valid_discounts = conditions.DiscountBase.objects.filter( - Q(discountforproduct__in=product_discounts) | - Q(discountforcategory__in=all_category_discounts) + "discount", ) all_subsets = [] for discounttype in discounttypes: - discounts = discounttype.objects.filter(id__in=valid_discounts) + discounts = discounttype.objects.all() ctrl = ConditionController.for_type(discounttype) discounts = ctrl.pre_filter(discounts, user) all_subsets.append(discounts) @@ -145,8 +142,8 @@ class DiscountController(object): from_filter = dict((i.id, i) for i in filtered_discounts) clause_sets = ( - product_discounts.filter(discount__in=filtered_discounts), - all_category_discounts.filter(discount__in=filtered_discounts), + product_clauses.filter(discount__in=filtered_discounts), + category_clauses.filter(discount__in=filtered_discounts), ) clause_sets = ( diff --git a/registrasion/controllers/flag.py b/registrasion/controllers/flag.py index aa11d53e..77d6476d 100644 --- a/registrasion/controllers/flag.py +++ b/registrasion/controllers/flag.py @@ -6,6 +6,7 @@ from collections import namedtuple from django.db.models import Count from django.db.models import Q +from .batch import BatchController from .conditions import ConditionController from registrasion.models import conditions @@ -47,8 +48,6 @@ class FlagController(object): a list is returned containing all of the products that are *not enabled*. ''' - print "GREPME: test_flags()" - if products is not None and product_quantities is not None: raise ValueError("Please specify only products or " "product_quantities") @@ -62,7 +61,7 @@ class FlagController(object): if products: # Simplify the query. - all_conditions = cls._filtered_flags(user, products) + all_conditions = cls._filtered_flags(user) else: all_conditions = [] @@ -86,6 +85,8 @@ class FlagController(object): # from the categories covered by this condition ids = [product.id for product in products] + + # TODO: This is re-evaluated a lot. all_products = inventory.Product.objects.filter(id__in=ids) cond = ( Q(flagbase_set=condition) | @@ -117,7 +118,7 @@ class FlagController(object): if not met and product not in messages: messages[product] = message - total_flags = FlagCounter.count() + total_flags = FlagCounter.count(user) valid = {} @@ -160,7 +161,8 @@ class FlagController(object): return error_fields @classmethod - def _filtered_flags(cls, user, products): + @BatchController.memoise + def _filtered_flags(cls, user): ''' Returns: @@ -171,26 +173,15 @@ class FlagController(object): types = list(ConditionController._controllers()) flagtypes = [i for i in types if issubclass(i, conditions.FlagBase)] - # Get all flags for the products and categories. - prods = ( - product.flagbase_set.all() - for product in products - ) - cats = ( - category.flagbase_set.all() - for category in set(product.category for product in products) - ) - all_flags = reduce(operator.or_, itertools.chain(prods, cats)) - all_subsets = [] for flagtype in flagtypes: - flags = flagtype.objects.filter(id__in=all_flags) + flags = flagtype.objects.all() ctrl = ConditionController.for_type(flagtype) flags = ctrl.pre_filter(flags, user) all_subsets.append(flags) - return itertools.chain(*all_subsets) + return list(itertools.chain(*all_subsets)) ConditionAndRemainder = namedtuple( @@ -220,11 +211,11 @@ _ConditionsCount = namedtuple( ) -# TODO: this should be cacheable. class FlagCounter(_FlagCounter): @classmethod - def count(cls): + @BatchController.memoise + def count(cls, user): # Get the count of how many conditions should exist per product flagbases = conditions.FlagBase.objects diff --git a/registrasion/controllers/product.py b/registrasion/controllers/product.py index 610c7f0d..4210bd7c 100644 --- a/registrasion/controllers/product.py +++ b/registrasion/controllers/product.py @@ -9,6 +9,7 @@ from django.db.models import Value from registrasion.models import commerce from registrasion.models import inventory +from .batch import BatchController from .category import CategoryController from .flag import FlagController @@ -34,18 +35,14 @@ class ProductController(object): if products is not None: all_products = set(itertools.chain(all_products, products)) - categories = set(product.category for product in all_products) - r = CategoryController.attach_user_remainders(user, categories) - cat_quants = dict((c, c) for c in r) - - r = ProductController.attach_user_remainders(user, all_products) - prod_quants = dict((p, p) for p in r) + category_remainders = CategoryController.user_remainders(user) + product_remainders = ProductController.user_remainders(user) passed_limits = set( product for product in all_products - if cat_quants[product.category].remainder > 0 - if prod_quants[product].remainder > 0 + if category_remainders[product.category.id] > 0 + if product_remainders[product.id] > 0 ) failed_and_messages = FlagController.test_flags( @@ -59,17 +56,16 @@ class ProductController(object): return out @classmethod - def attach_user_remainders(cls, user, products): + @BatchController.memoise + def user_remainders(cls, user): ''' Return: - queryset(inventory.Product): A queryset containing items from - ``product``, with an extra attribute -- remainder = the amount of - this item that is remaining. + Mapping[int->int]: A dictionary that maps the product ID to the + user's remainder for that product. ''' - ids = [product.id for product in products] - products = inventory.Product.objects.filter(id__in=ids) + products = inventory.Product.objects.all() cart_filter = ( Q(productitem__cart__user=user) & @@ -93,12 +89,4 @@ class ProductController(object): products = products.annotate(remainder=remainder) - return products - - def user_quantity_remaining(self, user): - ''' Returns the quantity of this product that the user add in the - current cart. ''' - - with_remainders = self.attach_user_remainders(user, [self.product]) - - return with_remainders[0].remainder + return dict((product.id, product.remainder) for product in products) diff --git a/registrasion/tests/test_batch.py b/registrasion/tests/test_batch.py new file mode 100644 index 00000000..70370799 --- /dev/null +++ b/registrasion/tests/test_batch.py @@ -0,0 +1,144 @@ +import datetime +import pytz + +from django.core.exceptions import ValidationError + +from controller_helpers import TestingCartController +from test_cart import RegistrationCartTestCase + +from registrasion.controllers.batch import BatchController +from registrasion.controllers.discount import DiscountController +from registrasion.controllers.product import ProductController +from registrasion.models import commerce +from registrasion.models import conditions + +UTC = pytz.timezone('UTC') + + +class BatchTestCase(RegistrationCartTestCase): + + def test_no_caches_outside_of_batches(self): + cache_1 = BatchController.get_cache(self.USER_1) + cache_2 = BatchController.get_cache(self.USER_2) + + # Identity testing is important here + self.assertIsNot(cache_1, cache_2) + + def test_cache_clears_at_batch_exit(self): + with BatchController.batch(self.USER_1): + cache_1 = BatchController.get_cache(self.USER_1) + + cache_2 = BatchController.get_cache(self.USER_1) + + self.assertIsNot(cache_1, cache_2) + + def test_caches_identical_within_nestings(self): + with BatchController.batch(self.USER_1): + cache_1 = BatchController.get_cache(self.USER_1) + + with BatchController.batch(self.USER_2): + cache_2 = BatchController.get_cache(self.USER_1) + + cache_3 = BatchController.get_cache(self.USER_1) + + self.assertIs(cache_1, cache_2) + self.assertIs(cache_2, cache_3) + + def test_caches_are_independent_for_different_users(self): + with BatchController.batch(self.USER_1): + cache_1 = BatchController.get_cache(self.USER_1) + + with BatchController.batch(self.USER_2): + cache_2 = BatchController.get_cache(self.USER_2) + + self.assertIsNot(cache_1, cache_2) + + def test_cache_clears_are_independent_for_different_users(self): + with BatchController.batch(self.USER_1): + cache_1 = BatchController.get_cache(self.USER_1) + + with BatchController.batch(self.USER_2): + cache_2 = BatchController.get_cache(self.USER_2) + + with BatchController.batch(self.USER_2): + cache_3 = BatchController.get_cache(self.USER_2) + + cache_4 = BatchController.get_cache(self.USER_1) + + self.assertIs(cache_1, cache_4) + self.assertIsNot(cache_1, cache_2) + self.assertIsNot(cache_2, cache_3) + + def test_new_caches_for_new_batches(self): + with BatchController.batch(self.USER_1): + cache_1 = BatchController.get_cache(self.USER_1) + + with BatchController.batch(self.USER_1): + cache_2 = BatchController.get_cache(self.USER_1) + + with BatchController.batch(self.USER_1): + cache_3 = BatchController.get_cache(self.USER_1) + + self.assertIs(cache_2, cache_3) + self.assertIsNot(cache_1, cache_2) + + def test_memoisation_happens_in_batch_context(self): + with BatchController.batch(self.USER_1): + output_1 = self._memoiseme(self.USER_1) + + with BatchController.batch(self.USER_1): + output_2 = self._memoiseme(self.USER_1) + + self.assertIs(output_1, output_2) + + def test_memoisaion_does_not_happen_outside_batch_context(self): + output_1 = self._memoiseme(self.USER_1) + output_2 = self._memoiseme(self.USER_1) + + self.assertIsNot(output_1, output_2) + + def test_memoisation_is_user_independent(self): + with BatchController.batch(self.USER_1): + output_1 = self._memoiseme(self.USER_1) + with BatchController.batch(self.USER_2): + output_2 = self._memoiseme(self.USER_2) + output_3 = self._memoiseme(self.USER_1) + + self.assertIsNot(output_1, output_2) + self.assertIs(output_1, output_3) + + def test_memoisation_clears_outside_batches(self): + with BatchController.batch(self.USER_1): + output_1 = self._memoiseme(self.USER_1) + + with BatchController.batch(self.USER_1): + output_2 = self._memoiseme(self.USER_1) + + self.assertIsNot(output_1, output_2) + + @classmethod + @BatchController.memoise + def _memoiseme(self, user): + return object() + + def test_batch_end_functionality_is_called(self): + class Ender(object): + end_count = 0 + def end_batch(self): + self.end_count += 1 + + @BatchController.memoise + def get_ender(user): + return Ender() + + # end_batch should get called once on exiting the batch + with BatchController.batch(self.USER_1): + ender = get_ender(self.USER_1) + self.assertEquals(1, ender.end_count) + + # end_batch should get called once on exiting the batch + # no matter how deep the object gets cached + with BatchController.batch(self.USER_1): + with BatchController.batch(self.USER_1): + ender = get_ender(self.USER_1) + self.assertEquals(1, ender.end_count) diff --git a/registrasion/tests/test_cart.py b/registrasion/tests/test_cart.py index 790c1df9..619b9074 100644 --- a/registrasion/tests/test_cart.py +++ b/registrasion/tests/test_cart.py @@ -12,6 +12,7 @@ from registrasion.models import commerce from registrasion.models import conditions from registrasion.models import inventory from registrasion.models import people +from registrasion.controllers.batch import BatchController from registrasion.controllers.product import ProductController from controller_helpers import TestingCartController @@ -360,3 +361,65 @@ class BasicCartTests(RegistrationCartTestCase): def test_available_products_respects_product_limits(self): self.__available_products_test(self.PROD_4, 6) + + def test_cart_controller_for_user_is_memoised(self): + # - that for_user is memoised + with BatchController.batch(self.USER_1): + cart = TestingCartController.for_user(self.USER_1) + cart_2 = TestingCartController.for_user(self.USER_1) + self.assertIs(cart, cart_2) + + def test_cart_revision_does_not_increment_if_not_modified(self): + cart = TestingCartController.for_user(self.USER_1) + rev_0 = cart.cart.revision + + with BatchController.batch(self.USER_1): + # Memoise the cart + same_cart = TestingCartController.for_user(self.USER_1) + # Do nothing on exit + + rev_1 = self.reget(cart.cart).revision + self.assertEqual(rev_0, rev_1) + + def test_cart_revision_only_increments_at_end_of_batches(self): + cart = TestingCartController.for_user(self.USER_1) + rev_0 = cart.cart.revision + + with BatchController.batch(self.USER_1): + # Memoise the cart + same_cart = TestingCartController.for_user(self.USER_1) + same_cart.add_to_cart(self.PROD_1, 1) + rev_1 = self.reget(same_cart.cart).revision + + rev_2 = self.reget(cart.cart).revision + + self.assertEqual(rev_0, rev_1) + self.assertNotEqual(rev_0, rev_2) + + def test_cart_discounts_only_calculated_at_end_of_batches(self): + def count_discounts(cart): + return cart.cart.discountitem_set.count() + + cart = TestingCartController.for_user(self.USER_1) + self.make_discount_ceiling("FLOOZLE") + count_0 = count_discounts(cart) + + with BatchController.batch(self.USER_1): + # Memoise the cart + same_cart = TestingCartController.for_user(self.USER_1) + + with BatchController.batch(self.USER_1): + # Memoise the cart + same_cart_2 = TestingCartController.for_user(self.USER_1) + + same_cart_2.add_to_cart(self.PROD_1, 1) + count_1 = count_discounts(same_cart_2) + + count_2 = count_discounts(same_cart) + + count_3 = count_discounts(cart) + + self.assertEqual(0, count_0) + self.assertEqual(0, count_1) + self.assertEqual(0, count_2) + self.assertEqual(1, count_3) diff --git a/registrasion/views.py b/registrasion/views.py index a4dcceac..3d2a3c04 100644 --- a/registrasion/views.py +++ b/registrasion/views.py @@ -5,9 +5,10 @@ from registrasion import util from registrasion.models import commerce from registrasion.models import inventory from registrasion.models import people -from registrasion.controllers.discount import DiscountController +from registrasion.controllers.batch import BatchController from registrasion.controllers.cart import CartController from registrasion.controllers.credit_note import CreditNoteController +from registrasion.controllers.discount import DiscountController from registrasion.controllers.invoice import InvoiceController from registrasion.controllers.product import ProductController from registrasion.exceptions import CartValidationError @@ -170,18 +171,18 @@ def guided_registration(request): category__in=cats, ).select_related("category") - available_products = set(ProductController.available_products( - request.user, - products=all_products, - )) + with BatchController.batch(request.user): + available_products = set(ProductController.available_products( + request.user, + products=all_products, + )) - if len(available_products) == 0: - # We've filled in every category - attendee.completed_registration = True - attendee.save() - return next_step + if len(available_products) == 0: + # We've filled in every category + attendee.completed_registration = True + attendee.save() + return next_step - with CartController.operations_batch(request.user): for category in cats: products = [ i for i in available_products @@ -345,20 +346,21 @@ def product_category(request, category_id): category_id = int(category_id) # Routing is [0-9]+ category = inventory.Category.objects.get(pk=category_id) - products = ProductController.available_products( - request.user, - category=category, - ) - - if not products: - messages.warning( - request, - "There are no products available from category: " + category.name, + with BatchController.batch(request.user): + products = ProductController.available_products( + request.user, + category=category, ) - return redirect("dashboard") - p = _handle_products(request, category, products, PRODUCTS_FORM_PREFIX) - products_form, discounts, products_handled = p + if not products: + messages.warning( + request, + "There are no products available from category: " + category.name, + ) + return redirect("dashboard") + + p = _handle_products(request, category, products, PRODUCTS_FORM_PREFIX) + products_form, discounts, products_handled = p if request.POST and not voucher_handled and not products_form.errors: # Only return to the dashboard if we didn't add a voucher code