Merge branch 'batch_cache'
This commit is contained in:
		
						commit
						ded5114073
					
				
					 9 changed files with 430 additions and 179 deletions
				
			
		
							
								
								
									
										119
									
								
								registrasion/controllers/batch.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										119
									
								
								registrasion/controllers/batch.py
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,119 @@ | |||
| import contextlib | ||||
| import functools | ||||
| 
 | ||||
| from django.contrib.auth.models import User | ||||
| 
 | ||||
| 
 | ||||
| class BatchController(object): | ||||
|     ''' Batches are sets of operations where certain queries for users may be | ||||
|     repeated, but are also unlikely change within the boundaries of the batch. | ||||
| 
 | ||||
|     Batches are keyed per-user. You can mark the edge of the batch with the | ||||
|     ``batch`` context manager. If you nest calls to ``batch``, only the | ||||
|     outermost call will have the effect of ending the batch. | ||||
| 
 | ||||
|     Batches store results for functions wrapped with ``memoise``. These results | ||||
|     for the user are flushed at the end of the batch. | ||||
| 
 | ||||
|     If a return for a memoised function has a callable attribute called | ||||
|     ``end_batch``, that attribute will be called at the end of the batch. | ||||
| 
 | ||||
|     ''' | ||||
| 
 | ||||
|     _user_caches = {} | ||||
|     _NESTING_KEY = "nesting_count" | ||||
| 
 | ||||
|     @classmethod | ||||
|     @contextlib.contextmanager | ||||
|     def batch(cls, user): | ||||
|         ''' Marks the entry point for a batch for the given user. ''' | ||||
| 
 | ||||
|         cls._enter_batch_context(user) | ||||
|         try: | ||||
|             yield | ||||
|         finally: | ||||
|             # Make sure we clean up in case of errors. | ||||
|             cls._exit_batch_context(user) | ||||
| 
 | ||||
|     @classmethod | ||||
|     def _enter_batch_context(cls, user): | ||||
|         if user not in cls._user_caches: | ||||
|             cls._user_caches[user] = cls._new_cache() | ||||
| 
 | ||||
|         cache = cls._user_caches[user] | ||||
|         cache[cls._NESTING_KEY] += 1 | ||||
| 
 | ||||
|     @classmethod | ||||
|     def _exit_batch_context(cls, user): | ||||
|         cache = cls._user_caches[user] | ||||
|         cache[cls._NESTING_KEY] -= 1 | ||||
| 
 | ||||
|         if cache[cls._NESTING_KEY] == 0: | ||||
|             cls._call_end_batch_methods(user) | ||||
|             del cls._user_caches[user] | ||||
| 
 | ||||
|     @classmethod | ||||
|     def _call_end_batch_methods(cls, user): | ||||
|         cache = cls._user_caches[user] | ||||
|         ended = set() | ||||
|         while True: | ||||
|             keys = set(cache.keys()) | ||||
|             if ended == keys: | ||||
|                 break | ||||
|             keys_to_end = keys - ended | ||||
|             for key in keys_to_end: | ||||
|                 item = cache[key] | ||||
|                 if hasattr(item, 'end_batch') and callable(item.end_batch): | ||||
|                     item.end_batch() | ||||
|             ended = ended | keys_to_end | ||||
| 
 | ||||
|     @classmethod | ||||
|     def memoise(cls, func): | ||||
|         ''' Decorator that stores the result of the stored function in the | ||||
|         user's results cache until the batch completes. Keyword arguments are | ||||
|         not yet supported. | ||||
| 
 | ||||
|         Arguments: | ||||
|             func (callable(*a)): The function whose results we want | ||||
|                 to store. The positional arguments, ``a``, are used as cache | ||||
|                 keys. | ||||
| 
 | ||||
|         Returns: | ||||
|             callable(*a): The memosing version of ``func``. | ||||
| 
 | ||||
|         ''' | ||||
| 
 | ||||
|         @functools.wraps(func) | ||||
|         def f(*a): | ||||
| 
 | ||||
|             for arg in a: | ||||
|                 if isinstance(arg, User): | ||||
|                     user = arg | ||||
|                     break | ||||
|             else: | ||||
|                 raise ValueError("One position argument must be a User") | ||||
| 
 | ||||
|             func_key = (func, tuple(a)) | ||||
|             cache = cls.get_cache(user) | ||||
| 
 | ||||
|             if func_key not in cache: | ||||
|                 cache[func_key] = func(*a) | ||||
| 
 | ||||
|             return cache[func_key] | ||||
| 
 | ||||
|         return f | ||||
| 
 | ||||
|     @classmethod | ||||
|     def get_cache(cls, user): | ||||
|         if user not in cls._user_caches: | ||||
|             # Return blank cache here, we'll just discard :) | ||||
|             return cls._new_cache() | ||||
| 
 | ||||
|         return cls._user_caches[user] | ||||
| 
 | ||||
|     @classmethod | ||||
|     def _new_cache(cls): | ||||
|         ''' Returns a new cache dictionary. ''' | ||||
|         cache = {} | ||||
|         cache[cls._NESTING_KEY] = 0 | ||||
|         return cache | ||||
|  | @ -16,6 +16,7 @@ from registrasion.models import commerce | |||
| from registrasion.models import conditions | ||||
| from registrasion.models import inventory | ||||
| 
 | ||||
| from.batch import BatchController | ||||
| from .category import CategoryController | ||||
| from .discount import DiscountController | ||||
| from .flag import FlagController | ||||
|  | @ -34,10 +35,11 @@ def _modifies_cart(func): | |||
|     def inner(self, *a, **k): | ||||
|         self._fail_if_cart_is_not_active() | ||||
|         with transaction.atomic(): | ||||
|             with CartController.operations_batch(self.cart.user) as mark: | ||||
|                 mark.mark = True  # Marker that we've modified the cart | ||||
|             with BatchController.batch(self.cart.user): | ||||
|                 # Mark the version of self in the batch cache as modified | ||||
|                 memoised = self.for_user(self.cart.user) | ||||
|                 memoised._modified_by_batch = True | ||||
|                 return func(self, *a, **k) | ||||
| 
 | ||||
|     return inner | ||||
| 
 | ||||
| 
 | ||||
|  | @ -47,6 +49,7 @@ class CartController(object): | |||
|         self.cart = cart | ||||
| 
 | ||||
|     @classmethod | ||||
|     @BatchController.memoise | ||||
|     def for_user(cls, user): | ||||
|         ''' Returns the user's current cart, or creates a new cart | ||||
|         if there isn't one ready yet. ''' | ||||
|  | @ -64,59 +67,6 @@ class CartController(object): | |||
|             ) | ||||
|         return cls(existing) | ||||
| 
 | ||||
|     # Marks the carts that are currently in batches | ||||
|     _FOR_USER = {} | ||||
|     _BATCH_COUNT = collections.defaultdict(int) | ||||
|     _MODIFIED_CARTS = set() | ||||
| 
 | ||||
|     class _ModificationMarker(object): | ||||
|         pass | ||||
| 
 | ||||
|     @classmethod | ||||
|     @contextlib.contextmanager | ||||
|     def operations_batch(cls, user): | ||||
|         ''' Marks the boundary for a batch of operations on a user's cart. | ||||
| 
 | ||||
|         These markers can be nested. Only on exiting the outermost marker will | ||||
|         a batch be ended. | ||||
| 
 | ||||
|         When a batch is ended, discounts are recalculated, and the cart's | ||||
|         revision is increased. | ||||
|         ''' | ||||
| 
 | ||||
|         if user not in cls._FOR_USER: | ||||
|             _ctrl = cls.for_user(user) | ||||
|             cls._FOR_USER[user] = (_ctrl, _ctrl.cart.id) | ||||
| 
 | ||||
|         ctrl, _id = cls._FOR_USER[user] | ||||
| 
 | ||||
|         cls._BATCH_COUNT[_id] += 1 | ||||
|         try: | ||||
|             success = False | ||||
| 
 | ||||
|             marker = cls._ModificationMarker() | ||||
|             yield marker | ||||
| 
 | ||||
|             if hasattr(marker, "mark"): | ||||
|                 cls._MODIFIED_CARTS.add(_id) | ||||
| 
 | ||||
|             success = True | ||||
|         finally: | ||||
| 
 | ||||
|             cls._BATCH_COUNT[_id] -= 1 | ||||
| 
 | ||||
|             # Only end on the outermost batch marker, and only if | ||||
|             # it excited cleanly, and a modification occurred | ||||
|             modified = _id in cls._MODIFIED_CARTS | ||||
|             outermost = cls._BATCH_COUNT[_id] == 0 | ||||
|             if modified and outermost and success: | ||||
|                 ctrl._end_batch() | ||||
|                 cls._MODIFIED_CARTS.remove(_id) | ||||
| 
 | ||||
|             # Clear out the cache on the outermost operation | ||||
|             if outermost: | ||||
|                 del cls._FOR_USER[user] | ||||
| 
 | ||||
|     def _fail_if_cart_is_not_active(self): | ||||
|         self.cart.refresh_from_db() | ||||
|         if self.cart.status != commerce.Cart.STATUS_ACTIVE: | ||||
|  | @ -144,6 +94,13 @@ class CartController(object): | |||
|         self.cart.time_last_updated = timezone.now() | ||||
|         self.cart.reservation_duration = max(reservations) | ||||
| 
 | ||||
| 
 | ||||
|     def end_batch(self): | ||||
|         ''' Calls ``_end_batch`` if a modification has been performed in the | ||||
|         previous batch. ''' | ||||
|         if hasattr(self,'_modified_by_batch'): | ||||
|             self._end_batch() | ||||
| 
 | ||||
|     def _end_batch(self): | ||||
|         ''' Performs operations that occur occur at the end of a batch of | ||||
|         product changes/voucher applications etc. | ||||
|  | @ -217,16 +174,14 @@ class CartController(object): | |||
|         errors = [] | ||||
| 
 | ||||
|         # Pre-annotate products | ||||
|         products = [p for (p, q) in product_quantities] | ||||
|         r = ProductController.attach_user_remainders(self.cart.user, products) | ||||
|         with_remainders = dict((p, p) for p in r) | ||||
|         remainders = ProductController.user_remainders(self.cart.user) | ||||
| 
 | ||||
|         # Test each product limit here | ||||
|         for product, quantity in product_quantities: | ||||
|             if quantity < 0: | ||||
|                 errors.append((product, "Value must be zero or greater.")) | ||||
| 
 | ||||
|             limit = with_remainders[product].remainder | ||||
|             limit = remainders[product.id] | ||||
| 
 | ||||
|             if quantity > limit: | ||||
|                 errors.append(( | ||||
|  | @ -242,12 +197,11 @@ class CartController(object): | |||
|             by_cat[product.category].append((product, quantity)) | ||||
| 
 | ||||
|         # Pre-annotate categories | ||||
|         r = CategoryController.attach_user_remainders(self.cart.user, by_cat) | ||||
|         with_remainders = dict((cat, cat) for cat in r) | ||||
|         remainders = CategoryController.user_remainders(self.cart.user) | ||||
| 
 | ||||
|         # Test each category limit here | ||||
|         for category in by_cat: | ||||
|             limit = with_remainders[category].remainder | ||||
|             limit = remainders[category.id] | ||||
| 
 | ||||
|             # Get the amount so far in the cart | ||||
|             to_add = sum(i[1] for i in by_cat[category]) | ||||
|  |  | |||
|  | @ -7,6 +7,7 @@ from django.db.models import Sum | |||
| from django.db.models import When | ||||
| from django.db.models import Value | ||||
| 
 | ||||
| from .batch import BatchController | ||||
| 
 | ||||
| class AllProducts(object): | ||||
|     pass | ||||
|  | @ -39,17 +40,17 @@ class CategoryController(object): | |||
|         return set(i.category for i in available) | ||||
| 
 | ||||
|     @classmethod | ||||
|     def attach_user_remainders(cls, user, categories): | ||||
|     @BatchController.memoise | ||||
|     def user_remainders(cls, user): | ||||
|         ''' | ||||
| 
 | ||||
|         Return: | ||||
|             queryset(inventory.Product): A queryset containing items from | ||||
|             ``categories``, with an extra attribute -- remainder = the amount | ||||
|             of items from this category that is remaining. | ||||
|             Mapping[int->int]: A dictionary that maps the category ID to the | ||||
|             user's remainder for that category. | ||||
| 
 | ||||
|         ''' | ||||
| 
 | ||||
|         ids = [category.id for category in categories] | ||||
|         categories = inventory.Category.objects.filter(id__in=ids) | ||||
|         categories = inventory.Category.objects.all() | ||||
| 
 | ||||
|         cart_filter = ( | ||||
|             Q(product__productitem__cart__user=user) & | ||||
|  | @ -73,12 +74,4 @@ class CategoryController(object): | |||
| 
 | ||||
|         categories = categories.annotate(remainder=remainder) | ||||
| 
 | ||||
|         return categories | ||||
| 
 | ||||
|     def user_quantity_remaining(self, user): | ||||
|         ''' Returns the quantity of this product that the user add in the | ||||
|         current cart. ''' | ||||
| 
 | ||||
|         with_remainders = self.attach_user_remainders(user, [self.category]) | ||||
| 
 | ||||
|         return with_remainders[0].remainder | ||||
|         return dict((cat.id, cat.remainder) for cat in categories) | ||||
|  |  | |||
|  | @ -1,6 +1,8 @@ | |||
| import itertools | ||||
| 
 | ||||
| from conditions import ConditionController | ||||
| from .batch import BatchController | ||||
| from .conditions import ConditionController | ||||
| 
 | ||||
| from registrasion.models import commerce | ||||
| from registrasion.models import conditions | ||||
| 
 | ||||
|  | @ -10,7 +12,6 @@ from django.db.models import Sum | |||
| from django.db.models import Value | ||||
| from django.db.models import When | ||||
| 
 | ||||
| 
 | ||||
| class DiscountAndQuantity(object): | ||||
|     ''' Represents a discount that can be applied to a product or category | ||||
|     for a given user. | ||||
|  | @ -50,7 +51,22 @@ class DiscountController(object): | |||
|         categories and products. The discounts also list the available quantity | ||||
|         for this user, not including products that are pending purchase. ''' | ||||
| 
 | ||||
|         filtered_clauses = cls._filtered_discounts(user, categories, products) | ||||
|         filtered_clauses = cls._filtered_clauses(user) | ||||
| 
 | ||||
|         # clauses that match provided categories | ||||
|         categories = set(categories) | ||||
|         # clauses that match provided products | ||||
|         products = set(products) | ||||
|         # clauses that match categories for provided products | ||||
|         product_categories = set(product.category for product in products) | ||||
|         # (Not relevant: clauses that match products in provided categories) | ||||
|         all_categories = categories | product_categories | ||||
| 
 | ||||
|         filtered_clauses = ( | ||||
|             clause for clause in filtered_clauses | ||||
|             if hasattr(clause, 'product') and clause.product in products or | ||||
|             hasattr(clause, 'category') and clause.category in all_categories | ||||
|         ) | ||||
| 
 | ||||
|         discounts = [] | ||||
| 
 | ||||
|  | @ -84,12 +100,13 @@ class DiscountController(object): | |||
|         return discounts | ||||
| 
 | ||||
|     @classmethod | ||||
|     def _filtered_discounts(cls, user, categories, products): | ||||
|     @BatchController.memoise | ||||
|     def _filtered_clauses(cls, user): | ||||
|         ''' | ||||
| 
 | ||||
|         Returns: | ||||
|             Sequence[discountbase]: All discounts that passed the filter | ||||
|             function. | ||||
|             Sequence[DiscountForProduct | DiscountForCategory]: All clauses | ||||
|             that passed the filter function. | ||||
| 
 | ||||
|         ''' | ||||
| 
 | ||||
|  | @ -98,42 +115,22 @@ class DiscountController(object): | |||
|             i for i in types if issubclass(i, conditions.DiscountBase) | ||||
|         ] | ||||
| 
 | ||||
|         # discounts that match provided categories | ||||
|         category_discounts = conditions.DiscountForCategory.objects.filter( | ||||
|             category__in=categories | ||||
|         ) | ||||
|         # discounts that match provided products | ||||
|         product_discounts = conditions.DiscountForProduct.objects.filter( | ||||
|             product__in=products | ||||
|         ) | ||||
|         # discounts that match categories for provided products | ||||
|         product_category_discounts = conditions.DiscountForCategory.objects | ||||
|         product_category_discounts = product_category_discounts.filter( | ||||
|             category__in=(product.category for product in products) | ||||
|         ) | ||||
|         # (Not relevant: discounts that match products in provided categories) | ||||
| 
 | ||||
|         product_discounts = product_discounts.select_related( | ||||
|         product_clauses = conditions.DiscountForProduct.objects.all() | ||||
|         product_clauses = product_clauses.select_related( | ||||
|             "discount", | ||||
|             "product", | ||||
|             "product__category", | ||||
|         ) | ||||
| 
 | ||||
|         all_category_discounts = ( | ||||
|             category_discounts | product_category_discounts | ||||
|         ) | ||||
|         all_category_discounts = all_category_discounts.select_related( | ||||
|         category_clauses = conditions.DiscountForCategory.objects.all() | ||||
|         category_clauses = category_clauses.select_related( | ||||
|             "category", | ||||
|         ) | ||||
| 
 | ||||
|         valid_discounts = conditions.DiscountBase.objects.filter( | ||||
|             Q(discountforproduct__in=product_discounts) | | ||||
|             Q(discountforcategory__in=all_category_discounts) | ||||
|             "discount", | ||||
|         ) | ||||
| 
 | ||||
|         all_subsets = [] | ||||
| 
 | ||||
|         for discounttype in discounttypes: | ||||
|             discounts = discounttype.objects.filter(id__in=valid_discounts) | ||||
|             discounts = discounttype.objects.all() | ||||
|             ctrl = ConditionController.for_type(discounttype) | ||||
|             discounts = ctrl.pre_filter(discounts, user) | ||||
|             all_subsets.append(discounts) | ||||
|  | @ -145,8 +142,8 @@ class DiscountController(object): | |||
|         from_filter = dict((i.id, i) for i in filtered_discounts) | ||||
| 
 | ||||
|         clause_sets = ( | ||||
|             product_discounts.filter(discount__in=filtered_discounts), | ||||
|             all_category_discounts.filter(discount__in=filtered_discounts), | ||||
|             product_clauses.filter(discount__in=filtered_discounts), | ||||
|             category_clauses.filter(discount__in=filtered_discounts), | ||||
|         ) | ||||
| 
 | ||||
|         clause_sets = ( | ||||
|  |  | |||
|  | @ -6,6 +6,7 @@ from collections import namedtuple | |||
| from django.db.models import Count | ||||
| from django.db.models import Q | ||||
| 
 | ||||
| from .batch import BatchController | ||||
| from .conditions import ConditionController | ||||
| 
 | ||||
| from registrasion.models import conditions | ||||
|  | @ -47,8 +48,6 @@ class FlagController(object): | |||
|         a list is returned containing all of the products that are *not | ||||
|         enabled*. ''' | ||||
| 
 | ||||
|         print "GREPME: test_flags()" | ||||
| 
 | ||||
|         if products is not None and product_quantities is not None: | ||||
|             raise ValueError("Please specify only products or " | ||||
|                              "product_quantities") | ||||
|  | @ -62,7 +61,7 @@ class FlagController(object): | |||
| 
 | ||||
|         if products: | ||||
|             # Simplify the query. | ||||
|             all_conditions = cls._filtered_flags(user, products) | ||||
|             all_conditions = cls._filtered_flags(user) | ||||
|         else: | ||||
|             all_conditions = [] | ||||
| 
 | ||||
|  | @ -86,6 +85,8 @@ class FlagController(object): | |||
|             # from the categories covered by this condition | ||||
| 
 | ||||
|             ids = [product.id for product in products] | ||||
| 
 | ||||
|             # TODO: This is re-evaluated a lot. | ||||
|             all_products = inventory.Product.objects.filter(id__in=ids) | ||||
|             cond = ( | ||||
|                 Q(flagbase_set=condition) | | ||||
|  | @ -117,7 +118,7 @@ class FlagController(object): | |||
|                 if not met and product not in messages: | ||||
|                     messages[product] = message | ||||
| 
 | ||||
|         total_flags = FlagCounter.count() | ||||
|         total_flags = FlagCounter.count(user) | ||||
| 
 | ||||
|         valid = {} | ||||
| 
 | ||||
|  | @ -160,7 +161,8 @@ class FlagController(object): | |||
|         return error_fields | ||||
| 
 | ||||
|     @classmethod | ||||
|     def _filtered_flags(cls, user, products): | ||||
|     @BatchController.memoise | ||||
|     def _filtered_flags(cls, user): | ||||
|         ''' | ||||
| 
 | ||||
|         Returns: | ||||
|  | @ -171,26 +173,15 @@ class FlagController(object): | |||
|         types = list(ConditionController._controllers()) | ||||
|         flagtypes = [i for i in types if issubclass(i, conditions.FlagBase)] | ||||
| 
 | ||||
|         # Get all flags for the products and categories. | ||||
|         prods = ( | ||||
|             product.flagbase_set.all() | ||||
|             for product in products | ||||
|         ) | ||||
|         cats = ( | ||||
|             category.flagbase_set.all() | ||||
|             for category in set(product.category for product in products) | ||||
|         ) | ||||
|         all_flags = reduce(operator.or_, itertools.chain(prods, cats)) | ||||
| 
 | ||||
|         all_subsets = [] | ||||
| 
 | ||||
|         for flagtype in flagtypes: | ||||
|             flags = flagtype.objects.filter(id__in=all_flags) | ||||
|             flags = flagtype.objects.all() | ||||
|             ctrl = ConditionController.for_type(flagtype) | ||||
|             flags = ctrl.pre_filter(flags, user) | ||||
|             all_subsets.append(flags) | ||||
| 
 | ||||
|         return itertools.chain(*all_subsets) | ||||
|         return list(itertools.chain(*all_subsets)) | ||||
| 
 | ||||
| 
 | ||||
| ConditionAndRemainder = namedtuple( | ||||
|  | @ -220,11 +211,11 @@ _ConditionsCount = namedtuple( | |||
| ) | ||||
| 
 | ||||
| 
 | ||||
| # TODO: this should be cacheable. | ||||
| class FlagCounter(_FlagCounter): | ||||
| 
 | ||||
|     @classmethod | ||||
|     def count(cls): | ||||
|     @BatchController.memoise | ||||
|     def count(cls, user): | ||||
|         # Get the count of how many conditions should exist per product | ||||
|         flagbases = conditions.FlagBase.objects | ||||
| 
 | ||||
|  |  | |||
|  | @ -9,6 +9,7 @@ from django.db.models import Value | |||
| from registrasion.models import commerce | ||||
| from registrasion.models import inventory | ||||
| 
 | ||||
| from .batch import BatchController | ||||
| from .category import CategoryController | ||||
| from .flag import FlagController | ||||
| 
 | ||||
|  | @ -34,18 +35,14 @@ class ProductController(object): | |||
|         if products is not None: | ||||
|             all_products = set(itertools.chain(all_products, products)) | ||||
| 
 | ||||
|         categories = set(product.category for product in all_products) | ||||
|         r = CategoryController.attach_user_remainders(user, categories) | ||||
|         cat_quants = dict((c, c) for c in r) | ||||
| 
 | ||||
|         r = ProductController.attach_user_remainders(user, all_products) | ||||
|         prod_quants = dict((p, p) for p in r) | ||||
|         category_remainders = CategoryController.user_remainders(user) | ||||
|         product_remainders = ProductController.user_remainders(user) | ||||
| 
 | ||||
|         passed_limits = set( | ||||
|             product | ||||
|             for product in all_products | ||||
|             if cat_quants[product.category].remainder > 0 | ||||
|             if prod_quants[product].remainder > 0 | ||||
|             if category_remainders[product.category.id] > 0 | ||||
|             if product_remainders[product.id] > 0 | ||||
|         ) | ||||
| 
 | ||||
|         failed_and_messages = FlagController.test_flags( | ||||
|  | @ -59,17 +56,16 @@ class ProductController(object): | |||
|         return out | ||||
| 
 | ||||
|     @classmethod | ||||
|     def attach_user_remainders(cls, user, products): | ||||
|     @BatchController.memoise | ||||
|     def user_remainders(cls, user): | ||||
|         ''' | ||||
| 
 | ||||
|         Return: | ||||
|             queryset(inventory.Product): A queryset containing items from | ||||
|             ``product``, with an extra attribute -- remainder = the amount of | ||||
|             this item that is remaining. | ||||
|             Mapping[int->int]: A dictionary that maps the product ID to the | ||||
|             user's remainder for that product. | ||||
|         ''' | ||||
| 
 | ||||
|         ids = [product.id for product in products] | ||||
|         products = inventory.Product.objects.filter(id__in=ids) | ||||
|         products = inventory.Product.objects.all() | ||||
| 
 | ||||
|         cart_filter = ( | ||||
|             Q(productitem__cart__user=user) & | ||||
|  | @ -93,12 +89,4 @@ class ProductController(object): | |||
| 
 | ||||
|         products = products.annotate(remainder=remainder) | ||||
| 
 | ||||
|         return products | ||||
| 
 | ||||
|     def user_quantity_remaining(self, user): | ||||
|         ''' Returns the quantity of this product that the user add in the | ||||
|         current cart. ''' | ||||
| 
 | ||||
|         with_remainders = self.attach_user_remainders(user, [self.product]) | ||||
| 
 | ||||
|         return with_remainders[0].remainder | ||||
|         return dict((product.id, product.remainder) for product in products) | ||||
|  |  | |||
							
								
								
									
										144
									
								
								registrasion/tests/test_batch.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										144
									
								
								registrasion/tests/test_batch.py
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,144 @@ | |||
| import datetime | ||||
| import pytz | ||||
| 
 | ||||
| from django.core.exceptions import ValidationError | ||||
| 
 | ||||
| from controller_helpers import TestingCartController | ||||
| from test_cart import RegistrationCartTestCase | ||||
| 
 | ||||
| from registrasion.controllers.batch import BatchController | ||||
| from registrasion.controllers.discount import DiscountController | ||||
| from registrasion.controllers.product import ProductController | ||||
| from registrasion.models import commerce | ||||
| from registrasion.models import conditions | ||||
| 
 | ||||
| UTC = pytz.timezone('UTC') | ||||
| 
 | ||||
| 
 | ||||
| class BatchTestCase(RegistrationCartTestCase): | ||||
| 
 | ||||
|     def test_no_caches_outside_of_batches(self): | ||||
|         cache_1 = BatchController.get_cache(self.USER_1) | ||||
|         cache_2 = BatchController.get_cache(self.USER_2) | ||||
| 
 | ||||
|         # Identity testing is important here | ||||
|         self.assertIsNot(cache_1, cache_2) | ||||
| 
 | ||||
|     def test_cache_clears_at_batch_exit(self): | ||||
|         with BatchController.batch(self.USER_1): | ||||
|             cache_1 = BatchController.get_cache(self.USER_1) | ||||
| 
 | ||||
|         cache_2 = BatchController.get_cache(self.USER_1) | ||||
| 
 | ||||
|         self.assertIsNot(cache_1, cache_2) | ||||
| 
 | ||||
|     def test_caches_identical_within_nestings(self): | ||||
|         with BatchController.batch(self.USER_1): | ||||
|             cache_1 = BatchController.get_cache(self.USER_1) | ||||
| 
 | ||||
|             with BatchController.batch(self.USER_2): | ||||
|                 cache_2 = BatchController.get_cache(self.USER_1) | ||||
| 
 | ||||
|             cache_3 = BatchController.get_cache(self.USER_1) | ||||
| 
 | ||||
|         self.assertIs(cache_1, cache_2) | ||||
|         self.assertIs(cache_2, cache_3) | ||||
| 
 | ||||
|     def test_caches_are_independent_for_different_users(self): | ||||
|         with BatchController.batch(self.USER_1): | ||||
|             cache_1 = BatchController.get_cache(self.USER_1) | ||||
| 
 | ||||
|             with BatchController.batch(self.USER_2): | ||||
|                 cache_2 = BatchController.get_cache(self.USER_2) | ||||
| 
 | ||||
|         self.assertIsNot(cache_1, cache_2) | ||||
| 
 | ||||
|     def test_cache_clears_are_independent_for_different_users(self): | ||||
|         with BatchController.batch(self.USER_1): | ||||
|             cache_1 = BatchController.get_cache(self.USER_1) | ||||
| 
 | ||||
|             with BatchController.batch(self.USER_2): | ||||
|                 cache_2 = BatchController.get_cache(self.USER_2) | ||||
| 
 | ||||
|             with BatchController.batch(self.USER_2): | ||||
|                 cache_3 = BatchController.get_cache(self.USER_2) | ||||
| 
 | ||||
|             cache_4 = BatchController.get_cache(self.USER_1) | ||||
| 
 | ||||
|         self.assertIs(cache_1, cache_4) | ||||
|         self.assertIsNot(cache_1, cache_2) | ||||
|         self.assertIsNot(cache_2, cache_3) | ||||
| 
 | ||||
|     def test_new_caches_for_new_batches(self): | ||||
|         with BatchController.batch(self.USER_1): | ||||
|             cache_1 = BatchController.get_cache(self.USER_1) | ||||
| 
 | ||||
|         with BatchController.batch(self.USER_1): | ||||
|             cache_2 = BatchController.get_cache(self.USER_1) | ||||
| 
 | ||||
|             with BatchController.batch(self.USER_1): | ||||
|                 cache_3 = BatchController.get_cache(self.USER_1) | ||||
| 
 | ||||
|         self.assertIs(cache_2, cache_3) | ||||
|         self.assertIsNot(cache_1, cache_2) | ||||
| 
 | ||||
|     def test_memoisation_happens_in_batch_context(self): | ||||
|         with BatchController.batch(self.USER_1): | ||||
|             output_1 = self._memoiseme(self.USER_1) | ||||
| 
 | ||||
|             with BatchController.batch(self.USER_1): | ||||
|                 output_2 = self._memoiseme(self.USER_1) | ||||
| 
 | ||||
|         self.assertIs(output_1, output_2) | ||||
| 
 | ||||
|     def test_memoisaion_does_not_happen_outside_batch_context(self): | ||||
|         output_1 = self._memoiseme(self.USER_1) | ||||
|         output_2 = self._memoiseme(self.USER_1) | ||||
| 
 | ||||
|         self.assertIsNot(output_1, output_2) | ||||
| 
 | ||||
|     def test_memoisation_is_user_independent(self): | ||||
|         with BatchController.batch(self.USER_1): | ||||
|             output_1 = self._memoiseme(self.USER_1) | ||||
|             with BatchController.batch(self.USER_2): | ||||
|                 output_2 = self._memoiseme(self.USER_2) | ||||
|                 output_3 = self._memoiseme(self.USER_1) | ||||
| 
 | ||||
|         self.assertIsNot(output_1, output_2) | ||||
|         self.assertIs(output_1, output_3) | ||||
| 
 | ||||
|     def test_memoisation_clears_outside_batches(self): | ||||
|         with BatchController.batch(self.USER_1): | ||||
|             output_1 = self._memoiseme(self.USER_1) | ||||
| 
 | ||||
|         with BatchController.batch(self.USER_1): | ||||
|             output_2 = self._memoiseme(self.USER_1) | ||||
| 
 | ||||
|         self.assertIsNot(output_1, output_2) | ||||
| 
 | ||||
|     @classmethod | ||||
|     @BatchController.memoise | ||||
|     def _memoiseme(self, user): | ||||
|         return object() | ||||
| 
 | ||||
|     def test_batch_end_functionality_is_called(self): | ||||
|         class Ender(object): | ||||
|             end_count = 0 | ||||
|             def end_batch(self): | ||||
|                 self.end_count += 1 | ||||
| 
 | ||||
|         @BatchController.memoise | ||||
|         def get_ender(user): | ||||
|             return Ender() | ||||
| 
 | ||||
|         # end_batch should get called once on exiting the batch | ||||
|         with BatchController.batch(self.USER_1): | ||||
|             ender = get_ender(self.USER_1) | ||||
|         self.assertEquals(1, ender.end_count) | ||||
| 
 | ||||
|         # end_batch should get called once on exiting the batch | ||||
|         # no matter how deep the object gets cached | ||||
|         with BatchController.batch(self.USER_1): | ||||
|             with BatchController.batch(self.USER_1): | ||||
|                 ender = get_ender(self.USER_1) | ||||
|         self.assertEquals(1, ender.end_count) | ||||
|  | @ -12,6 +12,7 @@ from registrasion.models import commerce | |||
| from registrasion.models import conditions | ||||
| from registrasion.models import inventory | ||||
| from registrasion.models import people | ||||
| from registrasion.controllers.batch import BatchController | ||||
| from registrasion.controllers.product import ProductController | ||||
| 
 | ||||
| from controller_helpers import TestingCartController | ||||
|  | @ -360,3 +361,65 @@ class BasicCartTests(RegistrationCartTestCase): | |||
| 
 | ||||
|     def test_available_products_respects_product_limits(self): | ||||
|         self.__available_products_test(self.PROD_4, 6) | ||||
| 
 | ||||
|     def test_cart_controller_for_user_is_memoised(self): | ||||
|         # - that for_user is memoised | ||||
|         with BatchController.batch(self.USER_1): | ||||
|             cart = TestingCartController.for_user(self.USER_1) | ||||
|             cart_2 = TestingCartController.for_user(self.USER_1) | ||||
|         self.assertIs(cart, cart_2) | ||||
| 
 | ||||
|     def test_cart_revision_does_not_increment_if_not_modified(self): | ||||
|         cart = TestingCartController.for_user(self.USER_1) | ||||
|         rev_0 = cart.cart.revision | ||||
| 
 | ||||
|         with BatchController.batch(self.USER_1): | ||||
|             # Memoise the cart | ||||
|             same_cart = TestingCartController.for_user(self.USER_1) | ||||
|             # Do nothing on exit | ||||
|          | ||||
|         rev_1 = self.reget(cart.cart).revision | ||||
|         self.assertEqual(rev_0, rev_1) | ||||
| 
 | ||||
|     def test_cart_revision_only_increments_at_end_of_batches(self): | ||||
|         cart = TestingCartController.for_user(self.USER_1) | ||||
|         rev_0 = cart.cart.revision | ||||
| 
 | ||||
|         with BatchController.batch(self.USER_1): | ||||
|             # Memoise the cart | ||||
|             same_cart = TestingCartController.for_user(self.USER_1) | ||||
|             same_cart.add_to_cart(self.PROD_1, 1) | ||||
|             rev_1 = self.reget(same_cart.cart).revision | ||||
| 
 | ||||
|         rev_2 = self.reget(cart.cart).revision | ||||
| 
 | ||||
|         self.assertEqual(rev_0, rev_1) | ||||
|         self.assertNotEqual(rev_0, rev_2) | ||||
| 
 | ||||
|     def test_cart_discounts_only_calculated_at_end_of_batches(self): | ||||
|         def count_discounts(cart): | ||||
|             return cart.cart.discountitem_set.count() | ||||
| 
 | ||||
|         cart = TestingCartController.for_user(self.USER_1) | ||||
|         self.make_discount_ceiling("FLOOZLE") | ||||
|         count_0 = count_discounts(cart) | ||||
| 
 | ||||
|         with BatchController.batch(self.USER_1): | ||||
|             # Memoise the cart | ||||
|             same_cart = TestingCartController.for_user(self.USER_1) | ||||
| 
 | ||||
|             with BatchController.batch(self.USER_1): | ||||
|                 # Memoise the cart | ||||
|                 same_cart_2 = TestingCartController.for_user(self.USER_1) | ||||
| 
 | ||||
|                 same_cart_2.add_to_cart(self.PROD_1, 1) | ||||
|                 count_1 = count_discounts(same_cart_2) | ||||
| 
 | ||||
|             count_2 = count_discounts(same_cart) | ||||
| 
 | ||||
|         count_3 = count_discounts(cart) | ||||
| 
 | ||||
|         self.assertEqual(0, count_0) | ||||
|         self.assertEqual(0, count_1) | ||||
|         self.assertEqual(0, count_2) | ||||
|         self.assertEqual(1, count_3) | ||||
|  |  | |||
|  | @ -5,9 +5,10 @@ from registrasion import util | |||
| from registrasion.models import commerce | ||||
| from registrasion.models import inventory | ||||
| from registrasion.models import people | ||||
| from registrasion.controllers.discount import DiscountController | ||||
| from registrasion.controllers.batch import BatchController | ||||
| from registrasion.controllers.cart import CartController | ||||
| from registrasion.controllers.credit_note import CreditNoteController | ||||
| from registrasion.controllers.discount import DiscountController | ||||
| from registrasion.controllers.invoice import InvoiceController | ||||
| from registrasion.controllers.product import ProductController | ||||
| from registrasion.exceptions import CartValidationError | ||||
|  | @ -170,18 +171,18 @@ def guided_registration(request): | |||
|             category__in=cats, | ||||
|         ).select_related("category") | ||||
| 
 | ||||
|         available_products = set(ProductController.available_products( | ||||
|             request.user, | ||||
|             products=all_products, | ||||
|         )) | ||||
|         with BatchController.batch(request.user): | ||||
|             available_products = set(ProductController.available_products( | ||||
|                 request.user, | ||||
|                 products=all_products, | ||||
|             )) | ||||
| 
 | ||||
|         if len(available_products) == 0: | ||||
|             # We've filled in every category | ||||
|             attendee.completed_registration = True | ||||
|             attendee.save() | ||||
|             return next_step | ||||
|             if len(available_products) == 0: | ||||
|                 # We've filled in every category | ||||
|                 attendee.completed_registration = True | ||||
|                 attendee.save() | ||||
|                 return next_step | ||||
| 
 | ||||
|         with CartController.operations_batch(request.user): | ||||
|             for category in cats: | ||||
|                 products = [ | ||||
|                     i for i in available_products | ||||
|  | @ -345,20 +346,21 @@ def product_category(request, category_id): | |||
|     category_id = int(category_id)  # Routing is [0-9]+ | ||||
|     category = inventory.Category.objects.get(pk=category_id) | ||||
| 
 | ||||
|     products = ProductController.available_products( | ||||
|         request.user, | ||||
|         category=category, | ||||
|     ) | ||||
| 
 | ||||
|     if not products: | ||||
|         messages.warning( | ||||
|             request, | ||||
|             "There are no products available from category: " + category.name, | ||||
|     with BatchController.batch(request.user): | ||||
|         products = ProductController.available_products( | ||||
|             request.user, | ||||
|             category=category, | ||||
|         ) | ||||
|         return redirect("dashboard") | ||||
| 
 | ||||
|     p = _handle_products(request, category, products, PRODUCTS_FORM_PREFIX) | ||||
|     products_form, discounts, products_handled = p | ||||
|         if not products: | ||||
|             messages.warning( | ||||
|                 request, | ||||
|                 "There are no products available from category: " + category.name, | ||||
|             ) | ||||
|             return redirect("dashboard") | ||||
| 
 | ||||
|         p = _handle_products(request, category, products, PRODUCTS_FORM_PREFIX) | ||||
|         products_form, discounts, products_handled = p | ||||
| 
 | ||||
|     if request.POST and not voucher_handled and not products_form.errors: | ||||
|         # Only return to the dashboard if we didn't add a voucher code | ||||
|  |  | |||
		Loading…
	
	Add table
		
		Reference in a new issue
	
	 Christopher Neugebauer
						Christopher Neugebauer