Merge branch 'query-optimisation'
This commit is contained in:
		
						commit
						6956c78b0d
					
				
					 14 changed files with 1064 additions and 449 deletions
				
			
		|  | @ -1,6 +1,6 @@ | |||
| import collections | ||||
| import contextlib | ||||
| import datetime | ||||
| import discount | ||||
| import functools | ||||
| import itertools | ||||
| 
 | ||||
|  | @ -8,6 +8,7 @@ from django.core.exceptions import ObjectDoesNotExist | |||
| from django.core.exceptions import ValidationError | ||||
| from django.db import transaction | ||||
| from django.db.models import Max | ||||
| from django.db.models import Q | ||||
| from django.utils import timezone | ||||
| 
 | ||||
| from registrasion.exceptions import CartValidationError | ||||
|  | @ -15,19 +16,27 @@ from registrasion.models import commerce | |||
| from registrasion.models import conditions | ||||
| from registrasion.models import inventory | ||||
| 
 | ||||
| from category import CategoryController | ||||
| from conditions import ConditionController | ||||
| from product import ProductController | ||||
| from .category import CategoryController | ||||
| from .discount import DiscountController | ||||
| from .flag import FlagController | ||||
| from .product import ProductController | ||||
| 
 | ||||
| 
 | ||||
| def _modifies_cart(func): | ||||
|     ''' Decorator that makes the wrapped function raise ValidationError | ||||
|     if we're doing something that could modify the cart. ''' | ||||
|     if we're doing something that could modify the cart. | ||||
| 
 | ||||
|     It also wraps the execution of this function in a database transaction, | ||||
|     and marks the boundaries of a cart operations batch. | ||||
|     ''' | ||||
| 
 | ||||
|     @functools.wraps(func) | ||||
|     def inner(self, *a, **k): | ||||
|         self._fail_if_cart_is_not_active() | ||||
|         return func(self, *a, **k) | ||||
|         with transaction.atomic(): | ||||
|             with CartController.operations_batch(self.cart.user) as mark: | ||||
|                 mark.mark = True  # Marker that we've modified the cart | ||||
|                 return func(self, *a, **k) | ||||
| 
 | ||||
|     return inner | ||||
| 
 | ||||
|  | @ -55,13 +64,65 @@ class CartController(object): | |||
|             ) | ||||
|         return cls(existing) | ||||
| 
 | ||||
|     # Marks the carts that are currently in batches | ||||
|     _FOR_USER = {} | ||||
|     _BATCH_COUNT = collections.defaultdict(int) | ||||
|     _MODIFIED_CARTS = set() | ||||
| 
 | ||||
|     class _ModificationMarker(object): | ||||
|         pass | ||||
| 
 | ||||
|     @classmethod | ||||
|     @contextlib.contextmanager | ||||
|     def operations_batch(cls, user): | ||||
|         ''' Marks the boundary for a batch of operations on a user's cart. | ||||
| 
 | ||||
|         These markers can be nested. Only on exiting the outermost marker will | ||||
|         a batch be ended. | ||||
| 
 | ||||
|         When a batch is ended, discounts are recalculated, and the cart's | ||||
|         revision is increased. | ||||
|         ''' | ||||
| 
 | ||||
|         if user not in cls._FOR_USER: | ||||
|             _ctrl = cls.for_user(user) | ||||
|             cls._FOR_USER[user] = (_ctrl, _ctrl.cart.id) | ||||
| 
 | ||||
|         ctrl, _id = cls._FOR_USER[user] | ||||
| 
 | ||||
|         cls._BATCH_COUNT[_id] += 1 | ||||
|         try: | ||||
|             success = False | ||||
| 
 | ||||
|             marker = cls._ModificationMarker() | ||||
|             yield marker | ||||
| 
 | ||||
|             if hasattr(marker, "mark"): | ||||
|                 cls._MODIFIED_CARTS.add(_id) | ||||
| 
 | ||||
|             success = True | ||||
|         finally: | ||||
| 
 | ||||
|             cls._BATCH_COUNT[_id] -= 1 | ||||
| 
 | ||||
|             # Only end on the outermost batch marker, and only if | ||||
|             # it excited cleanly, and a modification occurred | ||||
|             modified = _id in cls._MODIFIED_CARTS | ||||
|             outermost = cls._BATCH_COUNT[_id] == 0 | ||||
|             if modified and outermost and success: | ||||
|                 ctrl._end_batch() | ||||
|                 cls._MODIFIED_CARTS.remove(_id) | ||||
| 
 | ||||
|             # Clear out the cache on the outermost operation | ||||
|             if outermost: | ||||
|                 del cls._FOR_USER[user] | ||||
| 
 | ||||
|     def _fail_if_cart_is_not_active(self): | ||||
|         self.cart.refresh_from_db() | ||||
|         if self.cart.status != commerce.Cart.STATUS_ACTIVE: | ||||
|             raise ValidationError("You can only amend active carts.") | ||||
| 
 | ||||
|     @_modifies_cart | ||||
|     def extend_reservation(self): | ||||
|     def _autoextend_reservation(self): | ||||
|         ''' Updates the cart's time last updated value, which is used to | ||||
|         determine whether the cart has reserved the items and discounts it | ||||
|         holds. ''' | ||||
|  | @ -83,21 +144,25 @@ class CartController(object): | |||
|         self.cart.time_last_updated = timezone.now() | ||||
|         self.cart.reservation_duration = max(reservations) | ||||
| 
 | ||||
|     @_modifies_cart | ||||
|     def end_batch(self): | ||||
|     def _end_batch(self): | ||||
|         ''' Performs operations that occur occur at the end of a batch of | ||||
|         product changes/voucher applications etc. | ||||
|         THIS SHOULD BE PRIVATE | ||||
| 
 | ||||
|         You need to call this after you've finished modifying the user's cart. | ||||
|         This is normally done by wrapping a block of code using | ||||
|         ``operations_batch``. | ||||
| 
 | ||||
|         ''' | ||||
| 
 | ||||
|         self.recalculate_discounts() | ||||
|         self.cart.refresh_from_db() | ||||
| 
 | ||||
|         self.extend_reservation() | ||||
|         self._recalculate_discounts() | ||||
| 
 | ||||
|         self._autoextend_reservation() | ||||
|         self.cart.revision += 1 | ||||
|         self.cart.save() | ||||
| 
 | ||||
|     @_modifies_cart | ||||
|     @transaction.atomic | ||||
|     def set_quantities(self, product_quantities): | ||||
|         ''' Sets the quantities on each of the products on each of the | ||||
|         products specified. Raises an exception (ValidationError) if a limit | ||||
|  | @ -122,24 +187,28 @@ class CartController(object): | |||
|         # Validate that the limits we're adding are OK | ||||
|         self._test_limits(all_product_quantities) | ||||
| 
 | ||||
|         new_items = [] | ||||
|         products = [] | ||||
|         for product, quantity in product_quantities: | ||||
|             try: | ||||
|                 product_item = commerce.ProductItem.objects.get( | ||||
|                     cart=self.cart, | ||||
|                     product=product, | ||||
|                 ) | ||||
|                 product_item.quantity = quantity | ||||
|                 product_item.save() | ||||
|             except ObjectDoesNotExist: | ||||
|                 commerce.ProductItem.objects.create( | ||||
|                     cart=self.cart, | ||||
|                     product=product, | ||||
|                     quantity=quantity, | ||||
|                 ) | ||||
|             products.append(product) | ||||
| 
 | ||||
|         items_in_cart.filter(quantity=0).delete() | ||||
|             if quantity == 0: | ||||
|                 continue | ||||
| 
 | ||||
|         self.end_batch() | ||||
|             item = commerce.ProductItem( | ||||
|                 cart=self.cart, | ||||
|                 product=product, | ||||
|                 quantity=quantity, | ||||
|             ) | ||||
|             new_items.append(item) | ||||
| 
 | ||||
|         to_delete = ( | ||||
|             Q(quantity=0) | | ||||
|             Q(product__in=products) | ||||
|         ) | ||||
| 
 | ||||
|         items_in_cart.filter(to_delete).delete() | ||||
|         commerce.ProductItem.objects.bulk_create(new_items) | ||||
| 
 | ||||
|     def _test_limits(self, product_quantities): | ||||
|         ''' Tests that the quantity changes we intend to make do not violate | ||||
|  | @ -147,13 +216,17 @@ class CartController(object): | |||
| 
 | ||||
|         errors = [] | ||||
| 
 | ||||
|         # Pre-annotate products | ||||
|         products = [p for (p, q) in product_quantities] | ||||
|         r = ProductController.attach_user_remainders(self.cart.user, products) | ||||
|         with_remainders = dict((p, p) for p in r) | ||||
| 
 | ||||
|         # Test each product limit here | ||||
|         for product, quantity in product_quantities: | ||||
|             if quantity < 0: | ||||
|                 errors.append((product, "Value must be zero or greater.")) | ||||
| 
 | ||||
|             prod = ProductController(product) | ||||
|             limit = prod.user_quantity_remaining(self.cart.user) | ||||
|             limit = with_remainders[product].remainder | ||||
| 
 | ||||
|             if quantity > limit: | ||||
|                 errors.append(( | ||||
|  | @ -168,10 +241,13 @@ class CartController(object): | |||
|         for product, quantity in product_quantities: | ||||
|             by_cat[product.category].append((product, quantity)) | ||||
| 
 | ||||
|         # Pre-annotate categories | ||||
|         r = CategoryController.attach_user_remainders(self.cart.user, by_cat) | ||||
|         with_remainders = dict((cat, cat) for cat in r) | ||||
| 
 | ||||
|         # Test each category limit here | ||||
|         for category in by_cat: | ||||
|             ctrl = CategoryController(category) | ||||
|             limit = ctrl.user_quantity_remaining(self.cart.user) | ||||
|             limit = with_remainders[category].remainder | ||||
| 
 | ||||
|             # Get the amount so far in the cart | ||||
|             to_add = sum(i[1] for i in by_cat[category]) | ||||
|  | @ -185,7 +261,7 @@ class CartController(object): | |||
|                 )) | ||||
| 
 | ||||
|         # Test the flag conditions | ||||
|         errs = ConditionController.test_flags( | ||||
|         errs = FlagController.test_flags( | ||||
|             self.cart.user, | ||||
|             product_quantities=product_quantities, | ||||
|         ) | ||||
|  | @ -212,7 +288,6 @@ class CartController(object): | |||
| 
 | ||||
|         # If successful... | ||||
|         self.cart.vouchers.add(voucher) | ||||
|         self.end_batch() | ||||
| 
 | ||||
|     def _test_voucher(self, voucher): | ||||
|         ''' Tests whether this voucher is allowed to be applied to this cart. | ||||
|  | @ -294,6 +369,7 @@ class CartController(object): | |||
|             errors.append(ve) | ||||
| 
 | ||||
|         items = commerce.ProductItem.objects.filter(cart=cart) | ||||
|         items = items.select_related("product", "product__category") | ||||
| 
 | ||||
|         product_quantities = list((i.product, i.quantity) for i in items) | ||||
|         try: | ||||
|  | @ -307,19 +383,24 @@ class CartController(object): | |||
|             self._append_errors(errors, ve) | ||||
| 
 | ||||
|         # Validate the discounts | ||||
|         discount_items = commerce.DiscountItem.objects.filter(cart=cart) | ||||
|         seen_discounts = set() | ||||
|         # TODO: refactor in terms of available_discounts | ||||
|         # why aren't we doing that here?! | ||||
| 
 | ||||
|         #     def available_discounts(cls, user, categories, products): | ||||
| 
 | ||||
|         products = [i.product for i in items] | ||||
|         discounts_with_quantity = DiscountController.available_discounts( | ||||
|             user, | ||||
|             [], | ||||
|             products, | ||||
|         ) | ||||
|         discounts = set(i.discount.id for i in discounts_with_quantity) | ||||
| 
 | ||||
|         discount_items = commerce.DiscountItem.objects.filter(cart=cart) | ||||
|         for discount_item in discount_items: | ||||
|             discount = discount_item.discount | ||||
|             if discount in seen_discounts: | ||||
|                 continue | ||||
|             seen_discounts.add(discount) | ||||
|             real_discount = conditions.DiscountBase.objects.get_subclass( | ||||
|                 pk=discount.pk) | ||||
|             cond = ConditionController.for_condition(real_discount) | ||||
| 
 | ||||
|             if not cond.is_met(user): | ||||
|             if discount.id not in discounts: | ||||
|                 errors.append( | ||||
|                     ValidationError("Discounts are no longer available") | ||||
|                 ) | ||||
|  | @ -328,7 +409,6 @@ class CartController(object): | |||
|             raise ValidationError(errors) | ||||
| 
 | ||||
|     @_modifies_cart | ||||
|     @transaction.atomic | ||||
|     def fix_simple_errors(self): | ||||
|         ''' This attempts to fix the easy errors raised by ValidationError. | ||||
|         This includes removing items from the cart that are no longer | ||||
|  | @ -360,11 +440,9 @@ class CartController(object): | |||
| 
 | ||||
|         self.set_quantities(zeros) | ||||
| 
 | ||||
|     @_modifies_cart | ||||
|     @transaction.atomic | ||||
|     def recalculate_discounts(self): | ||||
|         ''' Calculates all of the discounts available for this product. | ||||
|         ''' | ||||
|     def _recalculate_discounts(self): | ||||
|         ''' Calculates all of the discounts available for this product.''' | ||||
| 
 | ||||
|         # Delete the existing entries. | ||||
|         commerce.DiscountItem.objects.filter(cart=self.cart).delete() | ||||
|  | @ -374,7 +452,11 @@ class CartController(object): | |||
|         ) | ||||
| 
 | ||||
|         products = [i.product for i in product_items] | ||||
|         discounts = discount.available_discounts(self.cart.user, [], products) | ||||
|         discounts = DiscountController.available_discounts( | ||||
|             self.cart.user, | ||||
|             [], | ||||
|             products, | ||||
|         ) | ||||
| 
 | ||||
|         # The highest-value discounts will apply to the highest-value | ||||
|         # products first. | ||||
|  |  | |||
|  | @ -1,7 +1,11 @@ | |||
| from registrasion.models import commerce | ||||
| from registrasion.models import inventory | ||||
| 
 | ||||
| from django.db.models import Case | ||||
| from django.db.models import F, Q | ||||
| from django.db.models import Sum | ||||
| from django.db.models import When | ||||
| from django.db.models import Value | ||||
| 
 | ||||
| 
 | ||||
| class AllProducts(object): | ||||
|  | @ -34,25 +38,47 @@ class CategoryController(object): | |||
| 
 | ||||
|         return set(i.category for i in available) | ||||
| 
 | ||||
|     @classmethod | ||||
|     def attach_user_remainders(cls, user, categories): | ||||
|         ''' | ||||
| 
 | ||||
|         Return: | ||||
|             queryset(inventory.Product): A queryset containing items from | ||||
|             ``categories``, with an extra attribute -- remainder = the amount | ||||
|             of items from this category that is remaining. | ||||
|         ''' | ||||
| 
 | ||||
|         ids = [category.id for category in categories] | ||||
|         categories = inventory.Category.objects.filter(id__in=ids) | ||||
| 
 | ||||
|         cart_filter = ( | ||||
|             Q(product__productitem__cart__user=user) & | ||||
|             Q(product__productitem__cart__status=commerce.Cart.STATUS_PAID) | ||||
|         ) | ||||
| 
 | ||||
|         quantity = When( | ||||
|             cart_filter, | ||||
|             then='product__productitem__quantity' | ||||
|         ) | ||||
| 
 | ||||
|         quantity_or_zero = Case( | ||||
|             quantity, | ||||
|             default=Value(0), | ||||
|         ) | ||||
| 
 | ||||
|         remainder = Case( | ||||
|             When(limit_per_user=None, then=Value(99999999)), | ||||
|             default=F('limit_per_user') - Sum(quantity_or_zero), | ||||
|         ) | ||||
| 
 | ||||
|         categories = categories.annotate(remainder=remainder) | ||||
| 
 | ||||
|         return categories | ||||
| 
 | ||||
|     def user_quantity_remaining(self, user): | ||||
|         ''' Returns the number of items from this category that the user may | ||||
|         add in the current cart. ''' | ||||
|         ''' Returns the quantity of this product that the user add in the | ||||
|         current cart. ''' | ||||
| 
 | ||||
|         cat_limit = self.category.limit_per_user | ||||
|         with_remainders = self.attach_user_remainders(user, [self.category]) | ||||
| 
 | ||||
|         if cat_limit is None: | ||||
|             # We don't need to waste the following queries | ||||
|             return 99999999 | ||||
| 
 | ||||
|         carts = commerce.Cart.objects.filter( | ||||
|             user=user, | ||||
|             status=commerce.Cart.STATUS_PAID, | ||||
|         ) | ||||
| 
 | ||||
|         items = commerce.ProductItem.objects.filter( | ||||
|             cart__in=carts, | ||||
|             product__category=self.category, | ||||
|         ) | ||||
| 
 | ||||
|         cat_count = items.aggregate(Sum("quantity"))["quantity__sum"] or 0 | ||||
|         return cat_limit - cat_count | ||||
|         return with_remainders[0].remainder | ||||
|  |  | |||
|  | @ -1,36 +1,27 @@ | |||
| import itertools | ||||
| import operator | ||||
| 
 | ||||
| from collections import defaultdict | ||||
| from collections import namedtuple | ||||
| 
 | ||||
| from django.db.models import Case | ||||
| from django.db.models import F, Q | ||||
| from django.db.models import Sum | ||||
| from django.db.models import Value | ||||
| from django.db.models import When | ||||
| from django.utils import timezone | ||||
| 
 | ||||
| from registrasion.models import commerce | ||||
| from registrasion.models import conditions | ||||
| from registrasion.models import inventory | ||||
| 
 | ||||
| 
 | ||||
| ConditionAndRemainder = namedtuple( | ||||
|     "ConditionAndRemainder", | ||||
|     ( | ||||
|         "condition", | ||||
|         "remainder", | ||||
|     ), | ||||
| ) | ||||
| _BIG_QUANTITY = 99999999  # A big quantity | ||||
| 
 | ||||
| 
 | ||||
| class ConditionController(object): | ||||
|     ''' Base class for testing conditions that activate Flag | ||||
|     or Discount objects. ''' | ||||
| 
 | ||||
|     def __init__(self): | ||||
|         pass | ||||
|     def __init__(self, condition): | ||||
|         self.condition = condition | ||||
| 
 | ||||
|     @staticmethod | ||||
|     def for_condition(condition): | ||||
|         CONTROLLERS = { | ||||
|     def _controllers(): | ||||
|         return { | ||||
|             conditions.CategoryFlag: CategoryConditionController, | ||||
|             conditions.IncludedProductDiscount: ProductConditionController, | ||||
|             conditions.ProductFlag: ProductConditionController, | ||||
|  | @ -42,137 +33,49 @@ class ConditionController(object): | |||
|             conditions.VoucherFlag: VoucherConditionController, | ||||
|         } | ||||
| 
 | ||||
|     @staticmethod | ||||
|     def for_type(cls): | ||||
|         return ConditionController._controllers()[cls] | ||||
| 
 | ||||
|     @staticmethod | ||||
|     def for_condition(condition): | ||||
|         try: | ||||
|             return CONTROLLERS[type(condition)](condition) | ||||
|             return ConditionController.for_type(type(condition))(condition) | ||||
|         except KeyError: | ||||
|             return ConditionController() | ||||
| 
 | ||||
|     SINGLE = True | ||||
|     PLURAL = False | ||||
|     NONE = True | ||||
|     SOME = False | ||||
|     MESSAGE = { | ||||
|         NONE: { | ||||
|             SINGLE: | ||||
|                 "%(items)s is no longer available to you", | ||||
|             PLURAL: | ||||
|                 "%(items)s are no longer available to you", | ||||
|         }, | ||||
|         SOME: { | ||||
|             SINGLE: | ||||
|                 "Only %(remainder)d of the following item remains: %(items)s", | ||||
|             PLURAL: | ||||
|                 "Only %(remainder)d of the following items remain: %(items)s" | ||||
|         }, | ||||
|     } | ||||
| 
 | ||||
|     @classmethod | ||||
|     def test_flags( | ||||
|             cls, user, products=None, product_quantities=None): | ||||
|         ''' Evaluates all of the flag conditions on the given products. | ||||
|     def pre_filter(cls, queryset, user): | ||||
|         ''' Returns only the flag conditions that might be available for this | ||||
|         user. It should hopefully reduce the number of queries that need to be | ||||
|         executed to determine if a flag is met. | ||||
| 
 | ||||
|         If `product_quantities` is supplied, the condition is only met if it | ||||
|         will permit the sum of the product quantities for all of the products | ||||
|         it covers. Otherwise, it will be met if at least one item can be | ||||
|         accepted. | ||||
|         If this filtration implements the same query as is_met, then you should | ||||
|         be able to implement ``is_met()`` in terms of this. | ||||
| 
 | ||||
|         If all flag conditions pass, an empty list is returned, otherwise | ||||
|         a list is returned containing all of the products that are *not | ||||
|         enabled*. ''' | ||||
|         Arguments: | ||||
| 
 | ||||
|         if products is not None and product_quantities is not None: | ||||
|             raise ValueError("Please specify only products or " | ||||
|                              "product_quantities") | ||||
|         elif products is None: | ||||
|             products = set(i[0] for i in product_quantities) | ||||
|             quantities = dict((product, quantity) | ||||
|                               for product, quantity in product_quantities) | ||||
|         elif product_quantities is None: | ||||
|             products = set(products) | ||||
|             quantities = {} | ||||
|             queryset (Queryset[c]): The canditate conditions. | ||||
| 
 | ||||
|         # Get the conditions covered by the products themselves | ||||
|         prods = ( | ||||
|             product.flagbase_set.select_subclasses() | ||||
|             for product in products | ||||
|         ) | ||||
|         # Get the conditions covered by their categories | ||||
|         cats = ( | ||||
|             category.flagbase_set.select_subclasses() | ||||
|             for category in set(product.category for product in products) | ||||
|         ) | ||||
|             user (User): The user for whom we're testing these conditions. | ||||
| 
 | ||||
|         if products: | ||||
|             # Simplify the query. | ||||
|             all_conditions = reduce(operator.or_, itertools.chain(prods, cats)) | ||||
|         else: | ||||
|             all_conditions = [] | ||||
|         Returns: | ||||
|             Queryset[c]: A subset of the conditions that pass the pre-filter | ||||
|                 test for this user. | ||||
| 
 | ||||
|         # All disable-if-false conditions on a product need to be met | ||||
|         do_not_disable = defaultdict(lambda: True) | ||||
|         # At least one enable-if-true condition on a product must be met | ||||
|         do_enable = defaultdict(lambda: False) | ||||
|         # (if either sort of condition is present) | ||||
|         ''' | ||||
| 
 | ||||
|         messages = {} | ||||
|         # Default implementation does NOTHING. | ||||
|         return queryset | ||||
| 
 | ||||
|         for condition in all_conditions: | ||||
|             cond = cls.for_condition(condition) | ||||
|             remainder = cond.user_quantity_remaining(user) | ||||
|     def passes_filter(self, user): | ||||
|         ''' Returns true if the condition passes the filter ''' | ||||
| 
 | ||||
|             # Get all products covered by this condition, and the products | ||||
|             # from the categories covered by this condition | ||||
|             cond_products = condition.products.all() | ||||
|             from_category = inventory.Product.objects.filter( | ||||
|                 category__in=condition.categories.all(), | ||||
|             ).all() | ||||
|             all_products = cond_products | from_category | ||||
|             all_products = all_products.select_related("category") | ||||
|             # Remove the products that we aren't asking about | ||||
|             all_products = [ | ||||
|                 product | ||||
|                 for product in all_products | ||||
|                 if product in products | ||||
|             ] | ||||
|         cls = type(self.condition) | ||||
|         qs = cls.objects.filter(pk=self.condition.id) | ||||
|         return self.condition in self.pre_filter(qs, user) | ||||
| 
 | ||||
|             if quantities: | ||||
|                 consumed = sum(quantities[i] for i in all_products) | ||||
|             else: | ||||
|                 consumed = 1 | ||||
|             met = consumed <= remainder | ||||
| 
 | ||||
|             if not met: | ||||
|                 items = ", ".join(str(product) for product in all_products) | ||||
|                 base = cls.MESSAGE[remainder == 0][len(all_products) == 1] | ||||
|                 message = base % {"items": items, "remainder": remainder} | ||||
| 
 | ||||
|             for product in all_products: | ||||
|                 if condition.is_disable_if_false: | ||||
|                     do_not_disable[product] &= met | ||||
|                 else: | ||||
|                     do_enable[product] |= met | ||||
| 
 | ||||
|                 if not met and product not in messages: | ||||
|                     messages[product] = message | ||||
| 
 | ||||
|         valid = {} | ||||
|         for product in itertools.chain(do_not_disable, do_enable): | ||||
|             if product in do_enable: | ||||
|                 # If there's an enable-if-true, we need need of those met too. | ||||
|                 # (do_not_disable will default to true otherwise) | ||||
|                 valid[product] = do_not_disable[product] and do_enable[product] | ||||
|             elif product in do_not_disable: | ||||
|                 # If there's a disable-if-false condition, all must be met | ||||
|                 valid[product] = do_not_disable[product] | ||||
| 
 | ||||
|         error_fields = [ | ||||
|             (product, messages[product]) | ||||
|             for product in valid if not valid[product] | ||||
|         ] | ||||
| 
 | ||||
|         return error_fields | ||||
| 
 | ||||
|     def user_quantity_remaining(self, user): | ||||
|     def user_quantity_remaining(self, user, filtered=False): | ||||
|         ''' Returns the number of items covered by this flag condition the | ||||
|         user can add to the current cart. This default implementation returns | ||||
|         a big number if is_met() is true, otherwise 0. | ||||
|  | @ -180,144 +83,210 @@ class ConditionController(object): | |||
|         Either this method, or is_met() must be overridden in subclasses. | ||||
|         ''' | ||||
| 
 | ||||
|         return 99999999 if self.is_met(user) else 0 | ||||
|         return _BIG_QUANTITY if self.is_met(user, filtered) else 0 | ||||
| 
 | ||||
|     def is_met(self, user): | ||||
|     def is_met(self, user, filtered=False): | ||||
|         ''' Returns True if this flag condition is met, otherwise returns | ||||
|         False. | ||||
| 
 | ||||
|         Either this method, or user_quantity_remaining() must be overridden | ||||
|         in subclasses. | ||||
| 
 | ||||
|         Arguments: | ||||
| 
 | ||||
|             user (User): The user for whom this test must be met. | ||||
| 
 | ||||
|             filter (bool): If true, this condition was part of a queryset | ||||
|                 returned by pre_filter() for this user. | ||||
| 
 | ||||
|         ''' | ||||
|         return self.user_quantity_remaining(user) > 0 | ||||
|         return self.user_quantity_remaining(user, filtered) > 0 | ||||
| 
 | ||||
| 
 | ||||
| class CategoryConditionController(ConditionController): | ||||
| class IsMetByFilter(object): | ||||
| 
 | ||||
|     def __init__(self, condition): | ||||
|         self.condition = condition | ||||
|     def is_met(self, user, filtered=False): | ||||
|         ''' Returns True if this flag condition is met, otherwise returns | ||||
|         False. It determines if the condition is met by calling pre_filter | ||||
|         with a queryset containing only self.condition. ''' | ||||
| 
 | ||||
|     def is_met(self, user): | ||||
|         ''' returns True if the user has a product from a category that invokes | ||||
|         this condition in one of their carts ''' | ||||
|         if filtered: | ||||
|             return True  # Why query again? | ||||
| 
 | ||||
|         carts = commerce.Cart.objects.filter(user=user) | ||||
|         carts = carts.exclude(status=commerce.Cart.STATUS_RELEASED) | ||||
|         enabling_products = inventory.Product.objects.filter( | ||||
|             category=self.condition.enabling_category, | ||||
|         return self.passes_filter(user) | ||||
| 
 | ||||
| 
 | ||||
| class RemainderSetByFilter(object): | ||||
| 
 | ||||
|     def user_quantity_remaining(self, user, filtered=True): | ||||
|         ''' returns 0 if the date range is violated, otherwise, it will return | ||||
|         the quantity remaining under the stock limit. | ||||
| 
 | ||||
|         The filter for this condition must add an annotation called "remainder" | ||||
|         in order for this to work. | ||||
|         ''' | ||||
| 
 | ||||
|         if filtered: | ||||
|             if hasattr(self.condition, "remainder"): | ||||
|                 return self.condition.remainder | ||||
| 
 | ||||
|         # Mark self.condition with a remainder | ||||
|         qs = type(self.condition).objects.filter(pk=self.condition.id) | ||||
|         qs = self.pre_filter(qs, user) | ||||
| 
 | ||||
|         if len(qs) > 0: | ||||
|             return qs[0].remainder | ||||
|         else: | ||||
|             return 0 | ||||
| 
 | ||||
| 
 | ||||
| class CategoryConditionController(IsMetByFilter, ConditionController): | ||||
| 
 | ||||
|     @classmethod | ||||
|     def pre_filter(self, queryset, user): | ||||
|         ''' Returns all of the items from queryset where the user has a | ||||
|         product from a category invoking that item's condition in one of their | ||||
|         carts. ''' | ||||
| 
 | ||||
|         in_user_carts = Q( | ||||
|             enabling_category__product__productitem__cart__user=user | ||||
|         ) | ||||
|         products_count = commerce.ProductItem.objects.filter( | ||||
|             cart__in=carts, | ||||
|             product__in=enabling_products, | ||||
|         ).count() | ||||
|         return products_count > 0 | ||||
|         released = commerce.Cart.STATUS_RELEASED | ||||
|         in_released_carts = Q( | ||||
|             enabling_category__product__productitem__cart__status=released | ||||
|         ) | ||||
|         queryset = queryset.filter(in_user_carts) | ||||
|         queryset = queryset.exclude(in_released_carts) | ||||
| 
 | ||||
|         return queryset | ||||
| 
 | ||||
| 
 | ||||
| class ProductConditionController(ConditionController): | ||||
| class ProductConditionController(IsMetByFilter, ConditionController): | ||||
|     ''' Condition tests for ProductFlag and | ||||
|     IncludedProductDiscount. ''' | ||||
| 
 | ||||
|     def __init__(self, condition): | ||||
|         self.condition = condition | ||||
|     @classmethod | ||||
|     def pre_filter(self, queryset, user): | ||||
|         ''' Returns all of the items from queryset where the user has a | ||||
|         product invoking that item's condition in one of their carts. ''' | ||||
| 
 | ||||
|     def is_met(self, user): | ||||
|         ''' returns True if the user has a product that invokes this | ||||
|         condition in one of their carts ''' | ||||
|         in_user_carts = Q(enabling_products__productitem__cart__user=user) | ||||
|         released = commerce.Cart.STATUS_RELEASED | ||||
|         in_released_carts = Q( | ||||
|             enabling_products__productitem__cart__status=released | ||||
|         ) | ||||
|         queryset = queryset.filter(in_user_carts) | ||||
|         queryset = queryset.exclude(in_released_carts) | ||||
| 
 | ||||
|         carts = commerce.Cart.objects.filter(user=user) | ||||
|         carts = carts.exclude(status=commerce.Cart.STATUS_RELEASED) | ||||
|         products_count = commerce.ProductItem.objects.filter( | ||||
|             cart__in=carts, | ||||
|             product__in=self.condition.enabling_products.all(), | ||||
|         ).count() | ||||
|         return products_count > 0 | ||||
|         return queryset | ||||
| 
 | ||||
| 
 | ||||
| class TimeOrStockLimitConditionController(ConditionController): | ||||
| class TimeOrStockLimitConditionController( | ||||
|             RemainderSetByFilter, | ||||
|             ConditionController, | ||||
|         ): | ||||
|     ''' Common condition tests for TimeOrStockLimit Flag and | ||||
|     Discount.''' | ||||
| 
 | ||||
|     def __init__(self, ceiling): | ||||
|         self.ceiling = ceiling | ||||
|     @classmethod | ||||
|     def pre_filter(self, queryset, user): | ||||
|         ''' Returns all of the items from queryset where the date falls into | ||||
|         any specified range, but not yet where the stock limit is not yet | ||||
|         reached.''' | ||||
| 
 | ||||
|     def user_quantity_remaining(self, user): | ||||
|         ''' returns 0 if the date range is violated, otherwise, it will return | ||||
|         the quantity remaining under the stock limit. ''' | ||||
| 
 | ||||
|         # Test date range | ||||
|         if not self._test_date_range(): | ||||
|             return 0 | ||||
| 
 | ||||
|         return self._get_remaining_stock(user) | ||||
| 
 | ||||
|     def _test_date_range(self): | ||||
|         now = timezone.now() | ||||
| 
 | ||||
|         if self.ceiling.start_time is not None: | ||||
|             if now < self.ceiling.start_time: | ||||
|                 return False | ||||
|         # Keep items with no start time, or start time not yet met. | ||||
|         queryset = queryset.filter(Q(start_time=None) | Q(start_time__lte=now)) | ||||
|         queryset = queryset.filter(Q(end_time=None) | Q(end_time__gte=now)) | ||||
| 
 | ||||
|         if self.ceiling.end_time is not None: | ||||
|             if now > self.ceiling.end_time: | ||||
|                 return False | ||||
|         # Filter out items that have been reserved beyond the limits | ||||
|         quantity_or_zero = self._calculate_quantities(user) | ||||
| 
 | ||||
|         return True | ||||
|         remainder = Case( | ||||
|             When(limit=None, then=Value(_BIG_QUANTITY)), | ||||
|             default=F("limit") - Sum(quantity_or_zero), | ||||
|         ) | ||||
| 
 | ||||
|     def _get_remaining_stock(self, user): | ||||
|         ''' Returns the stock that remains under this ceiling, excluding the | ||||
|         user's current cart. ''' | ||||
|         queryset = queryset.annotate(remainder=remainder) | ||||
|         queryset = queryset.filter(remainder__gt=0) | ||||
| 
 | ||||
|         if self.ceiling.limit is None: | ||||
|             return 99999999 | ||||
|         return queryset | ||||
| 
 | ||||
|         # We care about all reserved carts, but not the user's current cart | ||||
|     @classmethod | ||||
|     def _relevant_carts(cls, user): | ||||
|         reserved_carts = commerce.Cart.reserved_carts() | ||||
|         reserved_carts = reserved_carts.exclude( | ||||
|             user=user, | ||||
|             status=commerce.Cart.STATUS_ACTIVE, | ||||
|         ) | ||||
| 
 | ||||
|         items = self._items() | ||||
|         items = items.filter(cart__in=reserved_carts) | ||||
|         count = items.aggregate(Sum("quantity"))["quantity__sum"] or 0 | ||||
| 
 | ||||
|         return self.ceiling.limit - count | ||||
|         return reserved_carts | ||||
| 
 | ||||
| 
 | ||||
| class TimeOrStockLimitFlagController( | ||||
|         TimeOrStockLimitConditionController): | ||||
| 
 | ||||
|     def _items(self): | ||||
|         category_products = inventory.Product.objects.filter( | ||||
|             category__in=self.ceiling.categories.all(), | ||||
|         ) | ||||
|         products = self.ceiling.products.all() | category_products | ||||
|     @classmethod | ||||
|     def _calculate_quantities(cls, user): | ||||
|         reserved_carts = cls._relevant_carts(user) | ||||
| 
 | ||||
|         product_items = commerce.ProductItem.objects.filter( | ||||
|             product__in=products.all(), | ||||
|         # Calculate category lines | ||||
|         item_cats = F('categories__product__productitem__product__category') | ||||
|         reserved_category_products = ( | ||||
|             Q(categories=item_cats) & | ||||
|             Q(categories__product__productitem__cart__in=reserved_carts) | ||||
|         ) | ||||
|         return product_items | ||||
| 
 | ||||
|         # Calculate product lines | ||||
|         reserved_products = ( | ||||
|             Q(products=F('products__productitem__product')) & | ||||
|             Q(products__productitem__cart__in=reserved_carts) | ||||
|         ) | ||||
| 
 | ||||
|         category_quantity_in_reserved_carts = When( | ||||
|             reserved_category_products, | ||||
|             then="categories__product__productitem__quantity", | ||||
|         ) | ||||
| 
 | ||||
|         product_quantity_in_reserved_carts = When( | ||||
|             reserved_products, | ||||
|             then="products__productitem__quantity", | ||||
|         ) | ||||
| 
 | ||||
|         quantity_or_zero = Case( | ||||
|             category_quantity_in_reserved_carts, | ||||
|             product_quantity_in_reserved_carts, | ||||
|             default=Value(0), | ||||
|         ) | ||||
| 
 | ||||
|         return quantity_or_zero | ||||
| 
 | ||||
| 
 | ||||
| class TimeOrStockLimitDiscountController(TimeOrStockLimitConditionController): | ||||
| 
 | ||||
|     def _items(self): | ||||
|         discount_items = commerce.DiscountItem.objects.filter( | ||||
|             discount=self.ceiling, | ||||
|     @classmethod | ||||
|     def _calculate_quantities(cls, user): | ||||
|         reserved_carts = cls._relevant_carts(user) | ||||
| 
 | ||||
|         quantity_in_reserved_carts = When( | ||||
|             discountitem__cart__in=reserved_carts, | ||||
|             then="discountitem__quantity" | ||||
|         ) | ||||
|         return discount_items | ||||
| 
 | ||||
|         quantity_or_zero = Case( | ||||
|             quantity_in_reserved_carts, | ||||
|             default=Value(0) | ||||
|         ) | ||||
| 
 | ||||
|         return quantity_or_zero | ||||
| 
 | ||||
| 
 | ||||
| class VoucherConditionController(ConditionController): | ||||
| class VoucherConditionController(IsMetByFilter, ConditionController): | ||||
|     ''' Condition test for VoucherFlag and VoucherDiscount.''' | ||||
| 
 | ||||
|     def __init__(self, condition): | ||||
|         self.condition = condition | ||||
|     @classmethod | ||||
|     def pre_filter(self, queryset, user): | ||||
|         ''' Returns all of the items from queryset where the user has entered | ||||
|         a voucher that invokes that item's condition in one of their carts. ''' | ||||
| 
 | ||||
|     def is_met(self, user): | ||||
|         ''' returns True if the user has the given voucher attached. ''' | ||||
|         carts_count = commerce.Cart.objects.filter( | ||||
|             user=user, | ||||
|             vouchers=self.condition.voucher, | ||||
|         ).count() | ||||
|         return carts_count > 0 | ||||
|         return queryset.filter(voucher__cart__user=user) | ||||
|  |  | |||
|  | @ -4,7 +4,11 @@ from conditions import ConditionController | |||
| from registrasion.models import commerce | ||||
| from registrasion.models import conditions | ||||
| 
 | ||||
| from django.db.models import Case | ||||
| from django.db.models import F, Q | ||||
| from django.db.models import Sum | ||||
| from django.db.models import Value | ||||
| from django.db.models import When | ||||
| 
 | ||||
| 
 | ||||
| class DiscountAndQuantity(object): | ||||
|  | @ -38,80 +42,158 @@ class DiscountAndQuantity(object): | |||
|         ) | ||||
| 
 | ||||
| 
 | ||||
| def available_discounts(user, categories, products): | ||||
|     ''' Returns all discounts available to this user for the given categories | ||||
|     and products. The discounts also list the available quantity for this user, | ||||
|     not including products that are pending purchase. ''' | ||||
| class DiscountController(object): | ||||
| 
 | ||||
|     # discounts that match provided categories | ||||
|     category_discounts = conditions.DiscountForCategory.objects.filter( | ||||
|         category__in=categories | ||||
|     ) | ||||
|     # discounts that match provided products | ||||
|     product_discounts = conditions.DiscountForProduct.objects.filter( | ||||
|         product__in=products | ||||
|     ) | ||||
|     # discounts that match categories for provided products | ||||
|     product_category_discounts = conditions.DiscountForCategory.objects.filter( | ||||
|         category__in=(product.category for product in products) | ||||
|     ) | ||||
|     # (Not relevant: discounts that match products in provided categories) | ||||
|     @classmethod | ||||
|     def available_discounts(cls, user, categories, products): | ||||
|         ''' Returns all discounts available to this user for the given | ||||
|         categories and products. The discounts also list the available quantity | ||||
|         for this user, not including products that are pending purchase. ''' | ||||
| 
 | ||||
|     product_discounts = product_discounts.select_related( | ||||
|         "product", | ||||
|         "product__category", | ||||
|     ) | ||||
|         filtered_clauses = cls._filtered_discounts(user, categories, products) | ||||
| 
 | ||||
|     all_category_discounts = category_discounts | product_category_discounts | ||||
|     all_category_discounts = all_category_discounts.select_related( | ||||
|         "category", | ||||
|     ) | ||||
|         discounts = [] | ||||
| 
 | ||||
|     # The set of all potential discounts | ||||
|     potential_discounts = set(itertools.chain( | ||||
|         product_discounts, | ||||
|         all_category_discounts, | ||||
|     )) | ||||
|         # Markers so that we don't need to evaluate given conditions | ||||
|         # more than once | ||||
|         accepted_discounts = set() | ||||
|         failed_discounts = set() | ||||
| 
 | ||||
|     discounts = [] | ||||
|         for clause in filtered_clauses: | ||||
|             discount = clause.discount | ||||
|             cond = ConditionController.for_condition(discount) | ||||
| 
 | ||||
|     # Markers so that we don't need to evaluate given conditions more than once | ||||
|     accepted_discounts = set() | ||||
|     failed_discounts = set() | ||||
|             past_use_count = clause.past_use_count | ||||
|             if past_use_count >= clause.quantity: | ||||
|                 # This clause has exceeded its use count | ||||
|                 pass | ||||
|             elif discount not in failed_discounts: | ||||
|                 # This clause is still available | ||||
|                 is_accepted = discount in accepted_discounts | ||||
|                 if is_accepted or cond.is_met(user, filtered=True): | ||||
|                     # This clause is valid for this user | ||||
|                     discounts.append(DiscountAndQuantity( | ||||
|                         discount=discount, | ||||
|                         clause=clause, | ||||
|                         quantity=clause.quantity - past_use_count, | ||||
|                     )) | ||||
|                     accepted_discounts.add(discount) | ||||
|                 else: | ||||
|                     # This clause is not valid for this user | ||||
|                     failed_discounts.add(discount) | ||||
|         return discounts | ||||
| 
 | ||||
|     for discount in potential_discounts: | ||||
|         real_discount = conditions.DiscountBase.objects.get_subclass( | ||||
|             pk=discount.discount.pk, | ||||
|     @classmethod | ||||
|     def _filtered_discounts(cls, user, categories, products): | ||||
|         ''' | ||||
| 
 | ||||
|         Returns: | ||||
|             Sequence[discountbase]: All discounts that passed the filter | ||||
|             function. | ||||
| 
 | ||||
|         ''' | ||||
| 
 | ||||
|         types = list(ConditionController._controllers()) | ||||
|         discounttypes = [ | ||||
|             i for i in types if issubclass(i, conditions.DiscountBase) | ||||
|         ] | ||||
| 
 | ||||
|         # discounts that match provided categories | ||||
|         category_discounts = conditions.DiscountForCategory.objects.filter( | ||||
|             category__in=categories | ||||
|         ) | ||||
|         cond = ConditionController.for_condition(real_discount) | ||||
| 
 | ||||
|         # Count the past uses of the given discount item. | ||||
|         # If this user has exceeded the limit for the clause, this clause | ||||
|         # is not available any more. | ||||
|         past_uses = commerce.DiscountItem.objects.filter( | ||||
|             cart__user=user, | ||||
|             cart__status=commerce.Cart.STATUS_PAID,  # Only past carts count | ||||
|             discount=real_discount, | ||||
|         # discounts that match provided products | ||||
|         product_discounts = conditions.DiscountForProduct.objects.filter( | ||||
|             product__in=products | ||||
|         ) | ||||
|         agg = past_uses.aggregate(Sum("quantity")) | ||||
|         past_use_count = agg["quantity__sum"] | ||||
|         if past_use_count is None: | ||||
|             past_use_count = 0 | ||||
|         # discounts that match categories for provided products | ||||
|         product_category_discounts = conditions.DiscountForCategory.objects | ||||
|         product_category_discounts = product_category_discounts.filter( | ||||
|             category__in=(product.category for product in products) | ||||
|         ) | ||||
|         # (Not relevant: discounts that match products in provided categories) | ||||
| 
 | ||||
|         if past_use_count >= discount.quantity: | ||||
|             # This clause has exceeded its use count | ||||
|             pass | ||||
|         elif real_discount not in failed_discounts: | ||||
|             # This clause is still available | ||||
|             if real_discount in accepted_discounts or cond.is_met(user): | ||||
|                 # This clause is valid for this user | ||||
|                 discounts.append(DiscountAndQuantity( | ||||
|                     discount=real_discount, | ||||
|                     clause=discount, | ||||
|                     quantity=discount.quantity - past_use_count, | ||||
|                 )) | ||||
|                 accepted_discounts.add(real_discount) | ||||
|             else: | ||||
|                 # This clause is not valid for this user | ||||
|                 failed_discounts.add(real_discount) | ||||
|     return discounts | ||||
|         product_discounts = product_discounts.select_related( | ||||
|             "product", | ||||
|             "product__category", | ||||
|         ) | ||||
| 
 | ||||
|         all_category_discounts = ( | ||||
|             category_discounts | product_category_discounts | ||||
|         ) | ||||
|         all_category_discounts = all_category_discounts.select_related( | ||||
|             "category", | ||||
|         ) | ||||
| 
 | ||||
|         valid_discounts = conditions.DiscountBase.objects.filter( | ||||
|             Q(discountforproduct__in=product_discounts) | | ||||
|             Q(discountforcategory__in=all_category_discounts) | ||||
|         ) | ||||
| 
 | ||||
|         all_subsets = [] | ||||
| 
 | ||||
|         for discounttype in discounttypes: | ||||
|             discounts = discounttype.objects.filter(id__in=valid_discounts) | ||||
|             ctrl = ConditionController.for_type(discounttype) | ||||
|             discounts = ctrl.pre_filter(discounts, user) | ||||
|             all_subsets.append(discounts) | ||||
| 
 | ||||
|         filtered_discounts = list(itertools.chain(*all_subsets)) | ||||
| 
 | ||||
|         # Map from discount key to itself | ||||
|         # (contains annotations needed in the future) | ||||
|         from_filter = dict((i.id, i) for i in filtered_discounts) | ||||
| 
 | ||||
|         clause_sets = ( | ||||
|             product_discounts.filter(discount__in=filtered_discounts), | ||||
|             all_category_discounts.filter(discount__in=filtered_discounts), | ||||
|         ) | ||||
| 
 | ||||
|         clause_sets = ( | ||||
|             cls._annotate_with_past_uses(i, user) for i in clause_sets | ||||
|         ) | ||||
| 
 | ||||
|         # The set of all potential discount clauses | ||||
|         discount_clauses = set(itertools.chain(*clause_sets)) | ||||
| 
 | ||||
|         # Replace discounts with the filtered ones | ||||
|         # These are the correct subclasses (saves query later on), and have | ||||
|         # correct annotations from filters if necessary. | ||||
|         for clause in discount_clauses: | ||||
|             clause.discount = from_filter[clause.discount.id] | ||||
| 
 | ||||
|         return discount_clauses | ||||
| 
 | ||||
|     @classmethod | ||||
|     def _annotate_with_past_uses(cls, queryset, user): | ||||
|         ''' Annotates the queryset with a usage count for that discount claus | ||||
|         by the given user. ''' | ||||
| 
 | ||||
|         if queryset.model == conditions.DiscountForCategory: | ||||
|             matches = ( | ||||
|                 Q(category=F('discount__discountitem__product__category')) | ||||
|             ) | ||||
|         elif queryset.model == conditions.DiscountForProduct: | ||||
|             matches = ( | ||||
|                 Q(product=F('discount__discountitem__product')) | ||||
|             ) | ||||
| 
 | ||||
|         in_carts = ( | ||||
|             Q(discount__discountitem__cart__user=user) & | ||||
|             Q(discount__discountitem__cart__status=commerce.Cart.STATUS_PAID) | ||||
|         ) | ||||
| 
 | ||||
|         past_use_quantity = When( | ||||
|             in_carts & matches, | ||||
|             then="discount__discountitem__quantity", | ||||
|         ) | ||||
| 
 | ||||
|         past_use_quantity_or_zero = Case( | ||||
|             past_use_quantity, | ||||
|             default=Value(0), | ||||
|         ) | ||||
| 
 | ||||
|         queryset = queryset.annotate( | ||||
|             past_use_count=Sum(past_use_quantity_or_zero) | ||||
|         ) | ||||
|         return queryset | ||||
|  |  | |||
							
								
								
									
										264
									
								
								registrasion/controllers/flag.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										264
									
								
								registrasion/controllers/flag.py
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,264 @@ | |||
| import itertools | ||||
| import operator | ||||
| 
 | ||||
| from collections import defaultdict | ||||
| from collections import namedtuple | ||||
| from django.db.models import Count | ||||
| from django.db.models import Q | ||||
| 
 | ||||
| from .conditions import ConditionController | ||||
| 
 | ||||
| from registrasion.models import conditions | ||||
| from registrasion.models import inventory | ||||
| 
 | ||||
| 
 | ||||
| class FlagController(object): | ||||
| 
 | ||||
|     SINGLE = True | ||||
|     PLURAL = False | ||||
|     NONE = True | ||||
|     SOME = False | ||||
|     MESSAGE = { | ||||
|         NONE: { | ||||
|             SINGLE: | ||||
|                 "%(items)s is no longer available to you", | ||||
|             PLURAL: | ||||
|                 "%(items)s are no longer available to you", | ||||
|         }, | ||||
|         SOME: { | ||||
|             SINGLE: | ||||
|                 "Only %(remainder)d of the following item remains: %(items)s", | ||||
|             PLURAL: | ||||
|                 "Only %(remainder)d of the following items remain: %(items)s" | ||||
|         }, | ||||
|     } | ||||
| 
 | ||||
|     @classmethod | ||||
|     def test_flags( | ||||
|             cls, user, products=None, product_quantities=None): | ||||
|         ''' Evaluates all of the flag conditions on the given products. | ||||
| 
 | ||||
|         If `product_quantities` is supplied, the condition is only met if it | ||||
|         will permit the sum of the product quantities for all of the products | ||||
|         it covers. Otherwise, it will be met if at least one item can be | ||||
|         accepted. | ||||
| 
 | ||||
|         If all flag conditions pass, an empty list is returned, otherwise | ||||
|         a list is returned containing all of the products that are *not | ||||
|         enabled*. ''' | ||||
| 
 | ||||
|         print "GREPME: test_flags()" | ||||
| 
 | ||||
|         if products is not None and product_quantities is not None: | ||||
|             raise ValueError("Please specify only products or " | ||||
|                              "product_quantities") | ||||
|         elif products is None: | ||||
|             products = set(i[0] for i in product_quantities) | ||||
|             quantities = dict((product, quantity) | ||||
|                               for product, quantity in product_quantities) | ||||
|         elif product_quantities is None: | ||||
|             products = set(products) | ||||
|             quantities = {} | ||||
| 
 | ||||
|         if products: | ||||
|             # Simplify the query. | ||||
|             all_conditions = cls._filtered_flags(user, products) | ||||
|         else: | ||||
|             all_conditions = [] | ||||
| 
 | ||||
|         # All disable-if-false conditions on a product need to be met | ||||
|         do_not_disable = defaultdict(lambda: True) | ||||
|         # At least one enable-if-true condition on a product must be met | ||||
|         do_enable = defaultdict(lambda: False) | ||||
|         # (if either sort of condition is present) | ||||
| 
 | ||||
|         # Count the number of conditions for a product | ||||
|         dif_count = defaultdict(int) | ||||
|         eit_count = defaultdict(int) | ||||
| 
 | ||||
|         messages = {} | ||||
| 
 | ||||
|         for condition in all_conditions: | ||||
|             cond = ConditionController.for_condition(condition) | ||||
|             remainder = cond.user_quantity_remaining(user, filtered=True) | ||||
| 
 | ||||
|             # Get all products covered by this condition, and the products | ||||
|             # from the categories covered by this condition | ||||
| 
 | ||||
|             ids = [product.id for product in products] | ||||
|             all_products = inventory.Product.objects.filter(id__in=ids) | ||||
|             cond = ( | ||||
|                 Q(flagbase_set=condition) | | ||||
|                 Q(category__in=condition.categories.all()) | ||||
|             ) | ||||
| 
 | ||||
|             all_products = all_products.filter(cond) | ||||
|             all_products = all_products.select_related("category") | ||||
| 
 | ||||
|             if quantities: | ||||
|                 consumed = sum(quantities[i] for i in all_products) | ||||
|             else: | ||||
|                 consumed = 1 | ||||
|             met = consumed <= remainder | ||||
| 
 | ||||
|             if not met: | ||||
|                 items = ", ".join(str(product) for product in all_products) | ||||
|                 base = cls.MESSAGE[remainder == 0][len(all_products) == 1] | ||||
|                 message = base % {"items": items, "remainder": remainder} | ||||
| 
 | ||||
|             for product in all_products: | ||||
|                 if condition.is_disable_if_false: | ||||
|                     do_not_disable[product] &= met | ||||
|                     dif_count[product] += 1 | ||||
|                 else: | ||||
|                     do_enable[product] |= met | ||||
|                     eit_count[product] += 1 | ||||
| 
 | ||||
|                 if not met and product not in messages: | ||||
|                     messages[product] = message | ||||
| 
 | ||||
|         total_flags = FlagCounter.count() | ||||
| 
 | ||||
|         valid = {} | ||||
| 
 | ||||
|         # the problem is that now, not every condition falls into | ||||
|         # do_not_disable or do_enable ''' | ||||
|         # You should look into this, chris :) | ||||
| 
 | ||||
|         for product in products: | ||||
|             if quantities: | ||||
|                 if quantities[product] == 0: | ||||
|                     continue | ||||
| 
 | ||||
|             f = total_flags.get(product) | ||||
|             if f.dif > 0 and f.dif != dif_count[product]: | ||||
|                 do_not_disable[product] = False | ||||
|                 if product not in messages: | ||||
|                     messages[product] = "Some disable-if-false " \ | ||||
|                                         "conditions were not met" | ||||
|             if f.eit > 0 and product not in do_enable: | ||||
|                 do_enable[product] = False | ||||
|                 if product not in messages: | ||||
|                     messages[product] = "Some enable-if-true " \ | ||||
|                                         "conditions were not met" | ||||
| 
 | ||||
|         for product in itertools.chain(do_not_disable, do_enable): | ||||
|             f = total_flags.get(product) | ||||
|             if product in do_enable: | ||||
|                 # If there's an enable-if-true, we need need of those met too. | ||||
|                 # (do_not_disable will default to true otherwise) | ||||
|                 valid[product] = do_not_disable[product] and do_enable[product] | ||||
|             elif product in do_not_disable: | ||||
|                 # If there's a disable-if-false condition, all must be met | ||||
|                 valid[product] = do_not_disable[product] | ||||
| 
 | ||||
|         error_fields = [ | ||||
|             (product, messages[product]) | ||||
|             for product in valid if not valid[product] | ||||
|         ] | ||||
| 
 | ||||
|         return error_fields | ||||
| 
 | ||||
|     @classmethod | ||||
|     def _filtered_flags(cls, user, products): | ||||
|         ''' | ||||
| 
 | ||||
|         Returns: | ||||
|             Sequence[flagbase]: All flags that passed the filter function. | ||||
| 
 | ||||
|         ''' | ||||
| 
 | ||||
|         types = list(ConditionController._controllers()) | ||||
|         flagtypes = [i for i in types if issubclass(i, conditions.FlagBase)] | ||||
| 
 | ||||
|         # Get all flags for the products and categories. | ||||
|         prods = ( | ||||
|             product.flagbase_set.all() | ||||
|             for product in products | ||||
|         ) | ||||
|         cats = ( | ||||
|             category.flagbase_set.all() | ||||
|             for category in set(product.category for product in products) | ||||
|         ) | ||||
|         all_flags = reduce(operator.or_, itertools.chain(prods, cats)) | ||||
| 
 | ||||
|         all_subsets = [] | ||||
| 
 | ||||
|         for flagtype in flagtypes: | ||||
|             flags = flagtype.objects.filter(id__in=all_flags) | ||||
|             ctrl = ConditionController.for_type(flagtype) | ||||
|             flags = ctrl.pre_filter(flags, user) | ||||
|             all_subsets.append(flags) | ||||
| 
 | ||||
|         return itertools.chain(*all_subsets) | ||||
| 
 | ||||
| 
 | ||||
| ConditionAndRemainder = namedtuple( | ||||
|     "ConditionAndRemainder", | ||||
|     ( | ||||
|         "condition", | ||||
|         "remainder", | ||||
|     ), | ||||
| ) | ||||
| 
 | ||||
| 
 | ||||
| _FlagCounter = namedtuple( | ||||
|     "_FlagCounter", | ||||
|     ( | ||||
|         "products", | ||||
|         "categories", | ||||
|     ), | ||||
| ) | ||||
| 
 | ||||
| 
 | ||||
| _ConditionsCount = namedtuple( | ||||
|     "ConditionsCount", | ||||
|     ( | ||||
|         "dif", | ||||
|         "eit", | ||||
|     ), | ||||
| ) | ||||
| 
 | ||||
| 
 | ||||
| # TODO: this should be cacheable. | ||||
| class FlagCounter(_FlagCounter): | ||||
| 
 | ||||
|     @classmethod | ||||
|     def count(cls): | ||||
|         # Get the count of how many conditions should exist per product | ||||
|         flagbases = conditions.FlagBase.objects | ||||
| 
 | ||||
|         types = ( | ||||
|             conditions.FlagBase.ENABLE_IF_TRUE, | ||||
|             conditions.FlagBase.DISABLE_IF_FALSE, | ||||
|         ) | ||||
|         keys = ("eit", "dif") | ||||
|         flags = [ | ||||
|             flagbases.filter( | ||||
|                 condition=condition_type | ||||
|             ).values( | ||||
|                 'products', 'categories' | ||||
|             ).annotate( | ||||
|                 count=Count('id') | ||||
|             ) | ||||
|             for condition_type in types | ||||
|         ] | ||||
| 
 | ||||
|         cats = defaultdict(lambda: defaultdict(int)) | ||||
|         prod = defaultdict(lambda: defaultdict(int)) | ||||
| 
 | ||||
|         for key, flagcounts in zip(keys, flags): | ||||
|             for row in flagcounts: | ||||
|                 if row["products"] is not None: | ||||
|                     prod[row["products"]][key] = row["count"] | ||||
|                 if row["categories"] is not None: | ||||
|                     cats[row["categories"]][key] = row["count"] | ||||
| 
 | ||||
|         return cls(products=prod, categories=cats) | ||||
| 
 | ||||
|     def get(self, product): | ||||
|         p = self.products[product.id] | ||||
|         c = self.categories[product.category.id] | ||||
|         eit = p["eit"] + c["eit"] | ||||
|         dif = p["dif"] + c["dif"] | ||||
|         return _ConditionsCount(dif=dif, eit=eit) | ||||
|  | @ -29,6 +29,7 @@ class InvoiceController(ForId, object): | |||
|         If such an invoice does not exist, the cart is validated, and if valid, | ||||
|         an invoice is generated.''' | ||||
| 
 | ||||
|         cart.refresh_from_db() | ||||
|         try: | ||||
|             invoice = commerce.Invoice.objects.exclude( | ||||
|                 status=commerce.Invoice.STATUS_VOID, | ||||
|  | @ -74,6 +75,8 @@ class InvoiceController(ForId, object): | |||
|     def _generate(cls, cart): | ||||
|         ''' Generates an invoice for the given cart. ''' | ||||
| 
 | ||||
|         cart.refresh_from_db() | ||||
| 
 | ||||
|         issued = timezone.now() | ||||
|         reservation_limit = cart.reservation_duration + cart.time_last_updated | ||||
|         # Never generate a due time that is before the issue time | ||||
|  | @ -96,6 +99,10 @@ class InvoiceController(ForId, object): | |||
|         ) | ||||
| 
 | ||||
|         product_items = commerce.ProductItem.objects.filter(cart=cart) | ||||
|         product_items = product_items.select_related( | ||||
|             "product", | ||||
|             "product__category", | ||||
|         ) | ||||
| 
 | ||||
|         if len(product_items) == 0: | ||||
|             raise ValidationError("Your cart is empty.") | ||||
|  | @ -103,29 +110,41 @@ class InvoiceController(ForId, object): | |||
|         product_items = product_items.order_by( | ||||
|             "product__category__order", "product__order" | ||||
|         ) | ||||
| 
 | ||||
|         discount_items = commerce.DiscountItem.objects.filter(cart=cart) | ||||
|         discount_items = discount_items.select_related( | ||||
|             "discount", | ||||
|             "product", | ||||
|             "product__category", | ||||
|         ) | ||||
| 
 | ||||
|         line_items = [] | ||||
| 
 | ||||
|         invoice_value = Decimal() | ||||
|         for item in product_items: | ||||
|             product = item.product | ||||
|             line_item = commerce.LineItem.objects.create( | ||||
|             line_item = commerce.LineItem( | ||||
|                 invoice=invoice, | ||||
|                 description="%s - %s" % (product.category.name, product.name), | ||||
|                 quantity=item.quantity, | ||||
|                 price=product.price, | ||||
|                 product=product, | ||||
|             ) | ||||
|             line_items.append(line_item) | ||||
|             invoice_value += line_item.quantity * line_item.price | ||||
| 
 | ||||
|         for item in discount_items: | ||||
|             line_item = commerce.LineItem.objects.create( | ||||
|             line_item = commerce.LineItem( | ||||
|                 invoice=invoice, | ||||
|                 description=item.discount.description, | ||||
|                 quantity=item.quantity, | ||||
|                 price=cls.resolve_discount_value(item) * -1, | ||||
|                 product=item.product, | ||||
|             ) | ||||
|             line_items.append(line_item) | ||||
|             invoice_value += line_item.quantity * line_item.price | ||||
| 
 | ||||
|         commerce.LineItem.objects.bulk_create(line_items) | ||||
| 
 | ||||
|         invoice.value = invoice_value | ||||
| 
 | ||||
|         invoice.save() | ||||
|  | @ -251,6 +270,9 @@ class InvoiceController(ForId, object): | |||
|     def _invoice_matches_cart(self): | ||||
|         ''' Returns true if there is no cart, or if the revision of this | ||||
|         invoice matches the current revision of the cart. ''' | ||||
| 
 | ||||
|         self._refresh() | ||||
| 
 | ||||
|         cart = self.invoice.cart | ||||
|         if not cart: | ||||
|             return True | ||||
|  |  | |||
|  | @ -1,11 +1,16 @@ | |||
| import itertools | ||||
| 
 | ||||
| from django.db.models import Case | ||||
| from django.db.models import F, Q | ||||
| from django.db.models import Sum | ||||
| from django.db.models import When | ||||
| from django.db.models import Value | ||||
| 
 | ||||
| from registrasion.models import commerce | ||||
| from registrasion.models import inventory | ||||
| 
 | ||||
| from category import CategoryController | ||||
| from conditions import ConditionController | ||||
| from .category import CategoryController | ||||
| from .flag import FlagController | ||||
| 
 | ||||
| 
 | ||||
| class ProductController(object): | ||||
|  | @ -16,9 +21,7 @@ class ProductController(object): | |||
|     @classmethod | ||||
|     def available_products(cls, user, category=None, products=None): | ||||
|         ''' Returns a list of all of the products that are available per | ||||
|         flag conditions from the given categories. | ||||
|         TODO: refactor so that all conditions are tested here and | ||||
|         can_add_with_flags calls this method. ''' | ||||
|         flag conditions from the given categories. ''' | ||||
|         if category is None and products is None: | ||||
|             raise ValueError("You must provide products or a category") | ||||
| 
 | ||||
|  | @ -31,22 +34,21 @@ class ProductController(object): | |||
|         if products is not None: | ||||
|             all_products = set(itertools.chain(all_products, products)) | ||||
| 
 | ||||
|         cat_quants = dict( | ||||
|             ( | ||||
|                 category, | ||||
|                 CategoryController(category).user_quantity_remaining(user), | ||||
|             ) | ||||
|             for category in set(product.category for product in all_products) | ||||
|         ) | ||||
|         categories = set(product.category for product in all_products) | ||||
|         r = CategoryController.attach_user_remainders(user, categories) | ||||
|         cat_quants = dict((c, c) for c in r) | ||||
| 
 | ||||
|         r = ProductController.attach_user_remainders(user, all_products) | ||||
|         prod_quants = dict((p, p) for p in r) | ||||
| 
 | ||||
|         passed_limits = set( | ||||
|             product | ||||
|             for product in all_products | ||||
|             if cat_quants[product.category] > 0 | ||||
|             if cls(product).user_quantity_remaining(user) > 0 | ||||
|             if cat_quants[product.category].remainder > 0 | ||||
|             if prod_quants[product].remainder > 0 | ||||
|         ) | ||||
| 
 | ||||
|         failed_and_messages = ConditionController.test_flags( | ||||
|         failed_and_messages = FlagController.test_flags( | ||||
|             user, products=passed_limits | ||||
|         ) | ||||
|         failed_conditions = set(i[0] for i in failed_and_messages) | ||||
|  | @ -56,26 +58,47 @@ class ProductController(object): | |||
| 
 | ||||
|         return out | ||||
| 
 | ||||
|     @classmethod | ||||
|     def attach_user_remainders(cls, user, products): | ||||
|         ''' | ||||
| 
 | ||||
|         Return: | ||||
|             queryset(inventory.Product): A queryset containing items from | ||||
|             ``product``, with an extra attribute -- remainder = the amount of | ||||
|             this item that is remaining. | ||||
|         ''' | ||||
| 
 | ||||
|         ids = [product.id for product in products] | ||||
|         products = inventory.Product.objects.filter(id__in=ids) | ||||
| 
 | ||||
|         cart_filter = ( | ||||
|             Q(productitem__cart__user=user) & | ||||
|             Q(productitem__cart__status=commerce.Cart.STATUS_PAID) | ||||
|         ) | ||||
| 
 | ||||
|         quantity = When( | ||||
|             cart_filter, | ||||
|             then='productitem__quantity' | ||||
|         ) | ||||
| 
 | ||||
|         quantity_or_zero = Case( | ||||
|             quantity, | ||||
|             default=Value(0), | ||||
|         ) | ||||
| 
 | ||||
|         remainder = Case( | ||||
|             When(limit_per_user=None, then=Value(99999999)), | ||||
|             default=F('limit_per_user') - Sum(quantity_or_zero), | ||||
|         ) | ||||
| 
 | ||||
|         products = products.annotate(remainder=remainder) | ||||
| 
 | ||||
|         return products | ||||
| 
 | ||||
|     def user_quantity_remaining(self, user): | ||||
|         ''' Returns the quantity of this product that the user add in the | ||||
|         current cart. ''' | ||||
| 
 | ||||
|         prod_limit = self.product.limit_per_user | ||||
|         with_remainders = self.attach_user_remainders(user, [self.product]) | ||||
| 
 | ||||
|         if prod_limit is None: | ||||
|             # Don't need to run the remaining queries | ||||
|             return 999999  # We can do better | ||||
| 
 | ||||
|         carts = commerce.Cart.objects.filter( | ||||
|             user=user, | ||||
|             status=commerce.Cart.STATUS_PAID, | ||||
|         ) | ||||
| 
 | ||||
|         items = commerce.ProductItem.objects.filter( | ||||
|             cart__in=carts, | ||||
|             product=self.product, | ||||
|         ) | ||||
| 
 | ||||
|         prod_count = items.aggregate(Sum("quantity"))["quantity__sum"] or 0 | ||||
| 
 | ||||
|         return prod_limit - prod_count | ||||
|         return with_remainders[0].remainder | ||||
|  |  | |||
|  | @ -4,7 +4,11 @@ from registrasion.controllers.category import CategoryController | |||
| 
 | ||||
| from collections import namedtuple | ||||
| from django import template | ||||
| from django.db.models import Case | ||||
| from django.db.models import Q | ||||
| from django.db.models import Sum | ||||
| from django.db.models import When | ||||
| from django.db.models import Value | ||||
| 
 | ||||
| register = template.Library() | ||||
| 
 | ||||
|  | @ -99,20 +103,33 @@ def items_purchased(context, category=None): | |||
| 
 | ||||
|     ''' | ||||
| 
 | ||||
|     all_items = commerce.ProductItem.objects.filter( | ||||
|         cart__user=context.request.user, | ||||
|         cart__status=commerce.Cart.STATUS_PAID, | ||||
|     ).select_related("product", "product__category") | ||||
|     in_cart = ( | ||||
|         Q(productitem__cart__user=context.request.user) & | ||||
|         Q(productitem__cart__status=commerce.Cart.STATUS_PAID) | ||||
|     ) | ||||
| 
 | ||||
|     quantities_in_cart = When( | ||||
|         in_cart, | ||||
|         then="productitem__quantity", | ||||
|     ) | ||||
| 
 | ||||
|     quantities_or_zero = Case( | ||||
|         quantities_in_cart, | ||||
|         default=Value(0), | ||||
|     ) | ||||
| 
 | ||||
|     products = inventory.Product.objects | ||||
| 
 | ||||
|     if category: | ||||
|         all_items = all_items.filter(product__category=category) | ||||
|         products = products.filter(category=category) | ||||
| 
 | ||||
|     products = products.select_related("category") | ||||
|     products = products.annotate(quantity=Sum(quantities_or_zero)) | ||||
|     products = products.filter(quantity__gt=0) | ||||
| 
 | ||||
|     pq = all_items.values("product").annotate(quantity=Sum("quantity")).all() | ||||
|     products = inventory.Product.objects.all() | ||||
|     out = [] | ||||
|     for item in pq: | ||||
|         prod = products.get(pk=item["product"]) | ||||
|         out.append(ProductAndQuantity(prod, item["quantity"])) | ||||
|     for prod in products: | ||||
|         out.append(ProductAndQuantity(prod, prod.quantity)) | ||||
|     return out | ||||
| 
 | ||||
| 
 | ||||
|  |  | |||
|  | @ -26,7 +26,7 @@ class RegistrationCartTestCase(SetTimeMixin, TestCase): | |||
|         super(RegistrationCartTestCase, self).setUp() | ||||
| 
 | ||||
|     def tearDown(self): | ||||
|         if False: | ||||
|         if True: | ||||
|             # If you're seeing segfaults in tests, enable this. | ||||
|             call_command( | ||||
|                 'flush', | ||||
|  |  | |||
|  | @ -6,6 +6,8 @@ from django.core.exceptions import ValidationError | |||
| from controller_helpers import TestingCartController | ||||
| from test_cart import RegistrationCartTestCase | ||||
| 
 | ||||
| from registrasion.controllers.discount import DiscountController | ||||
| from registrasion.controllers.product import ProductController | ||||
| from registrasion.models import commerce | ||||
| from registrasion.models import conditions | ||||
| 
 | ||||
|  | @ -135,6 +137,43 @@ class CeilingsTestCases(RegistrationCartTestCase): | |||
|         with self.assertRaises(ValidationError): | ||||
|             first_cart.validate_cart() | ||||
| 
 | ||||
|     def test_discount_ceiling_aggregates_products(self): | ||||
|         # Create two carts, add 1xprod_1 to each. Ceiling should disappear | ||||
|         # after second. | ||||
|         self.make_discount_ceiling( | ||||
|             "Multi-product limit discount ceiling", | ||||
|             limit=2, | ||||
|         ) | ||||
|         for i in xrange(2): | ||||
|             cart = TestingCartController.for_user(self.USER_1) | ||||
|             cart.add_to_cart(self.PROD_1, 1) | ||||
|             cart.next_cart() | ||||
| 
 | ||||
|         discounts = DiscountController.available_discounts( | ||||
|             self.USER_1, | ||||
|             [], | ||||
|             [self.PROD_1], | ||||
|         ) | ||||
| 
 | ||||
|         self.assertEqual(0, len(discounts)) | ||||
| 
 | ||||
|     def test_flag_ceiling_aggregates_products(self): | ||||
|         # Create two carts, add 1xprod_1 to each. Ceiling should disappear | ||||
|         # after second. | ||||
|         self.make_ceiling("Multi-product limit ceiling", limit=2) | ||||
| 
 | ||||
|         for i in xrange(2): | ||||
|             cart = TestingCartController.for_user(self.USER_1) | ||||
|             cart.add_to_cart(self.PROD_1, 1) | ||||
|             cart.next_cart() | ||||
| 
 | ||||
|         products = ProductController.available_products( | ||||
|             self.USER_1, | ||||
|             products=[self.PROD_1], | ||||
|         ) | ||||
| 
 | ||||
|         self.assertEqual(0, len(products)) | ||||
| 
 | ||||
|     def test_items_released_from_ceiling_by_refund(self): | ||||
|         self.make_ceiling("Limit ceiling", limit=1) | ||||
| 
 | ||||
|  |  | |||
|  | @ -4,7 +4,7 @@ from decimal import Decimal | |||
| 
 | ||||
| from registrasion.models import commerce | ||||
| from registrasion.models import conditions | ||||
| from registrasion.controllers import discount | ||||
| from registrasion.controllers.discount import DiscountController | ||||
| from controller_helpers import TestingCartController | ||||
| 
 | ||||
| from test_cart import RegistrationCartTestCase | ||||
|  | @ -243,22 +243,30 @@ class DiscountTestCase(RegistrationCartTestCase): | |||
|             # The discount is applied. | ||||
|             self.assertEqual(1, len(discount_items)) | ||||
| 
 | ||||
|     # Tests for the discount.available_discounts enumerator | ||||
|     # Tests for the DiscountController.available_discounts enumerator | ||||
|     def test_enumerate_no_discounts_for_no_input(self): | ||||
|         discounts = discount.available_discounts(self.USER_1, [], []) | ||||
|         discounts = DiscountController.available_discounts( | ||||
|             self.USER_1, | ||||
|             [], | ||||
|             [], | ||||
|         ) | ||||
|         self.assertEqual(0, len(discounts)) | ||||
| 
 | ||||
|     def test_enumerate_no_discounts_if_condition_not_met(self): | ||||
|         self.add_discount_prod_1_includes_cat_2(quantity=1) | ||||
| 
 | ||||
|         discounts = discount.available_discounts( | ||||
|         discounts = DiscountController.available_discounts( | ||||
|             self.USER_1, | ||||
|             [], | ||||
|             [self.PROD_3], | ||||
|         ) | ||||
|         self.assertEqual(0, len(discounts)) | ||||
| 
 | ||||
|         discounts = discount.available_discounts(self.USER_1, [self.CAT_2], []) | ||||
|         discounts = DiscountController.available_discounts( | ||||
|             self.USER_1, | ||||
|             [self.CAT_2], | ||||
|             [], | ||||
|         ) | ||||
|         self.assertEqual(0, len(discounts)) | ||||
| 
 | ||||
|     def test_category_discount_appears_once_if_met_twice(self): | ||||
|  | @ -267,7 +275,7 @@ class DiscountTestCase(RegistrationCartTestCase): | |||
|         cart = TestingCartController.for_user(self.USER_1) | ||||
|         cart.add_to_cart(self.PROD_1, 1)  # Enable the discount | ||||
| 
 | ||||
|         discounts = discount.available_discounts( | ||||
|         discounts = DiscountController.available_discounts( | ||||
|             self.USER_1, | ||||
|             [self.CAT_2], | ||||
|             [self.PROD_3], | ||||
|  | @ -280,7 +288,11 @@ class DiscountTestCase(RegistrationCartTestCase): | |||
|         cart = TestingCartController.for_user(self.USER_1) | ||||
|         cart.add_to_cart(self.PROD_1, 1)  # Enable the discount | ||||
| 
 | ||||
|         discounts = discount.available_discounts(self.USER_1, [self.CAT_2], []) | ||||
|         discounts = DiscountController.available_discounts( | ||||
|             self.USER_1, | ||||
|             [self.CAT_2], | ||||
|             [], | ||||
|         ) | ||||
|         self.assertEqual(1, len(discounts)) | ||||
| 
 | ||||
|     def test_category_discount_appears_with_product(self): | ||||
|  | @ -289,7 +301,7 @@ class DiscountTestCase(RegistrationCartTestCase): | |||
|         cart = TestingCartController.for_user(self.USER_1) | ||||
|         cart.add_to_cart(self.PROD_1, 1)  # Enable the discount | ||||
| 
 | ||||
|         discounts = discount.available_discounts( | ||||
|         discounts = DiscountController.available_discounts( | ||||
|             self.USER_1, | ||||
|             [], | ||||
|             [self.PROD_3], | ||||
|  | @ -302,7 +314,7 @@ class DiscountTestCase(RegistrationCartTestCase): | |||
|         cart = TestingCartController.for_user(self.USER_1) | ||||
|         cart.add_to_cart(self.PROD_1, 1)  # Enable the discount | ||||
| 
 | ||||
|         discounts = discount.available_discounts( | ||||
|         discounts = DiscountController.available_discounts( | ||||
|             self.USER_1, | ||||
|             [], | ||||
|             [self.PROD_3, self.PROD_4] | ||||
|  | @ -315,7 +327,7 @@ class DiscountTestCase(RegistrationCartTestCase): | |||
|         cart = TestingCartController.for_user(self.USER_1) | ||||
|         cart.add_to_cart(self.PROD_1, 1)  # Enable the discount | ||||
| 
 | ||||
|         discounts = discount.available_discounts( | ||||
|         discounts = DiscountController.available_discounts( | ||||
|             self.USER_1, | ||||
|             [], | ||||
|             [self.PROD_2], | ||||
|  | @ -328,7 +340,11 @@ class DiscountTestCase(RegistrationCartTestCase): | |||
|         cart = TestingCartController.for_user(self.USER_1) | ||||
|         cart.add_to_cart(self.PROD_1, 1)  # Enable the discount | ||||
| 
 | ||||
|         discounts = discount.available_discounts(self.USER_1, [self.CAT_1], []) | ||||
|         discounts = DiscountController.available_discounts( | ||||
|             self.USER_1, | ||||
|             [self.CAT_1], | ||||
|             [], | ||||
|         ) | ||||
|         self.assertEqual(0, len(discounts)) | ||||
| 
 | ||||
|     def test_discount_quantity_is_correct_before_first_purchase(self): | ||||
|  | @ -338,7 +354,11 @@ class DiscountTestCase(RegistrationCartTestCase): | |||
|         cart.add_to_cart(self.PROD_1, 1)  # Enable the discount | ||||
|         cart.add_to_cart(self.PROD_3, 1)  # Exhaust the quantity | ||||
| 
 | ||||
|         discounts = discount.available_discounts(self.USER_1, [self.CAT_2], []) | ||||
|         discounts = DiscountController.available_discounts( | ||||
|             self.USER_1, | ||||
|             [self.CAT_2], | ||||
|             [], | ||||
|         ) | ||||
|         self.assertEqual(2, discounts[0].quantity) | ||||
| 
 | ||||
|         cart.next_cart() | ||||
|  | @ -349,32 +369,63 @@ class DiscountTestCase(RegistrationCartTestCase): | |||
|         cart = TestingCartController.for_user(self.USER_1) | ||||
|         cart.add_to_cart(self.PROD_3, 1)  # Exhaust the quantity | ||||
| 
 | ||||
|         discounts = discount.available_discounts(self.USER_1, [self.CAT_2], []) | ||||
|         discounts = DiscountController.available_discounts( | ||||
|             self.USER_1, | ||||
|             [self.CAT_2], | ||||
|             [], | ||||
|         ) | ||||
|         self.assertEqual(1, discounts[0].quantity) | ||||
| 
 | ||||
|         cart.next_cart() | ||||
| 
 | ||||
|     def test_discount_is_gone_after_quantity_exhausted(self): | ||||
|         self.test_discount_quantity_is_correct_after_first_purchase() | ||||
|         discounts = discount.available_discounts(self.USER_1, [self.CAT_2], []) | ||||
|         discounts = DiscountController.available_discounts( | ||||
|             self.USER_1, | ||||
|             [self.CAT_2], | ||||
|             [], | ||||
|         ) | ||||
|         self.assertEqual(0, len(discounts)) | ||||
| 
 | ||||
|     def test_product_discount_enabled_twice_appears_twice(self): | ||||
|         self.add_discount_prod_1_includes_prod_3_and_prod_4(quantity=2) | ||||
|         cart = TestingCartController.for_user(self.USER_1) | ||||
|         cart.add_to_cart(self.PROD_1, 1)  # Enable the discount | ||||
|         discounts = discount.available_discounts( | ||||
|         discounts = DiscountController.available_discounts( | ||||
|             self.USER_1, | ||||
|             [], | ||||
|             [self.PROD_3, self.PROD_4], | ||||
|         ) | ||||
|         self.assertEqual(2, len(discounts)) | ||||
| 
 | ||||
|     def test_product_discount_applied_on_different_invoices(self): | ||||
|         # quantity=1 means "quantity per product" | ||||
|         self.add_discount_prod_1_includes_prod_3_and_prod_4(quantity=1) | ||||
|         cart = TestingCartController.for_user(self.USER_1) | ||||
|         cart.add_to_cart(self.PROD_1, 1)  # Enable the discount | ||||
|         discounts = DiscountController.available_discounts( | ||||
|             self.USER_1, | ||||
|             [], | ||||
|             [self.PROD_3, self.PROD_4], | ||||
|         ) | ||||
|         self.assertEqual(2, len(discounts)) | ||||
|         # adding one of PROD_3 should make it no longer an available discount. | ||||
|         cart.add_to_cart(self.PROD_3, 1) | ||||
|         cart.next_cart() | ||||
| 
 | ||||
|         # should still have (and only have) the discount for prod_4 | ||||
|         discounts = DiscountController.available_discounts( | ||||
|             self.USER_1, | ||||
|             [], | ||||
|             [self.PROD_3, self.PROD_4], | ||||
|         ) | ||||
|         self.assertEqual(1, len(discounts)) | ||||
| 
 | ||||
|     def test_discounts_are_released_by_refunds(self): | ||||
|         self.add_discount_prod_1_includes_prod_2(quantity=2) | ||||
|         cart = TestingCartController.for_user(self.USER_1) | ||||
|         cart.add_to_cart(self.PROD_1, 1)  # Enable the discount | ||||
|         discounts = discount.available_discounts( | ||||
|         discounts = DiscountController.available_discounts( | ||||
|             self.USER_1, | ||||
|             [], | ||||
|             [self.PROD_2], | ||||
|  | @ -388,7 +439,7 @@ class DiscountTestCase(RegistrationCartTestCase): | |||
| 
 | ||||
|         cart.next_cart() | ||||
| 
 | ||||
|         discounts = discount.available_discounts( | ||||
|         discounts = DiscountController.available_discounts( | ||||
|             self.USER_1, | ||||
|             [], | ||||
|             [self.PROD_2], | ||||
|  | @ -398,7 +449,7 @@ class DiscountTestCase(RegistrationCartTestCase): | |||
|         cart.cart.status = commerce.Cart.STATUS_RELEASED | ||||
|         cart.cart.save() | ||||
| 
 | ||||
|         discounts = discount.available_discounts( | ||||
|         discounts = DiscountController.available_discounts( | ||||
|             self.USER_1, | ||||
|             [], | ||||
|             [self.PROD_2], | ||||
|  |  | |||
|  | @ -25,3 +25,33 @@ def all_arguments_optional(ntcls): | |||
|     ) | ||||
| 
 | ||||
|     return ntcls | ||||
| 
 | ||||
| 
 | ||||
| def lazy(function, *args, **kwargs): | ||||
|     ''' Produces a callable so that functions can be lazily evaluated in | ||||
|     templates. | ||||
| 
 | ||||
|     Arguments: | ||||
| 
 | ||||
|         function (callable): The function to call at evaluation time. | ||||
| 
 | ||||
|         args: Positional arguments, passed directly to ``function``. | ||||
| 
 | ||||
|         kwargs: Keyword arguments, passed directly to ``function``. | ||||
| 
 | ||||
|     Return: | ||||
| 
 | ||||
|         callable: A callable that will evaluate a call to ``function`` with | ||||
|             the specified arguments. | ||||
| 
 | ||||
|     ''' | ||||
| 
 | ||||
|     NOT_EVALUATED = object() | ||||
|     retval = [NOT_EVALUATED] | ||||
| 
 | ||||
|     def evaluate(): | ||||
|         if retval[0] is NOT_EVALUATED: | ||||
|             retval[0] = function(*args, **kwargs) | ||||
|         return retval[0] | ||||
| 
 | ||||
|     return evaluate | ||||
|  |  | |||
|  | @ -5,7 +5,7 @@ from registrasion import util | |||
| from registrasion.models import commerce | ||||
| from registrasion.models import inventory | ||||
| from registrasion.models import people | ||||
| from registrasion.controllers import discount | ||||
| from registrasion.controllers.discount import DiscountController | ||||
| from registrasion.controllers.cart import CartController | ||||
| from registrasion.controllers.credit_note import CreditNoteController | ||||
| from registrasion.controllers.invoice import InvoiceController | ||||
|  | @ -181,33 +181,35 @@ def guided_registration(request): | |||
|             attendee.save() | ||||
|             return next_step | ||||
| 
 | ||||
|         for category in cats: | ||||
|             products = [ | ||||
|                 i for i in available_products | ||||
|                 if i.category == category | ||||
|             ] | ||||
|         with CartController.operations_batch(request.user): | ||||
|             for category in cats: | ||||
|                 products = [ | ||||
|                     i for i in available_products | ||||
|                     if i.category == category | ||||
|                 ] | ||||
| 
 | ||||
|             prefix = "category_" + str(category.id) | ||||
|             p = _handle_products(request, category, products, prefix) | ||||
|             products_form, discounts, products_handled = p | ||||
|                 prefix = "category_" + str(category.id) | ||||
|                 p = _handle_products(request, category, products, prefix) | ||||
|                 products_form, discounts, products_handled = p | ||||
| 
 | ||||
|             section = GuidedRegistrationSection( | ||||
|                 title=category.name, | ||||
|                 description=category.description, | ||||
|                 discounts=discounts, | ||||
|                 form=products_form, | ||||
|             ) | ||||
|                 section = GuidedRegistrationSection( | ||||
|                     title=category.name, | ||||
|                     description=category.description, | ||||
|                     discounts=discounts, | ||||
|                     form=products_form, | ||||
|                 ) | ||||
| 
 | ||||
|             if products: | ||||
|                 # This product category has items to show. | ||||
|                 sections.append(section) | ||||
|                 # Add this to the list of things to show if the form errors. | ||||
|                 request.session[SESSION_KEY].append(category.id) | ||||
|                 if products: | ||||
|                     # This product category has items to show. | ||||
|                     sections.append(section) | ||||
|                     # Add this to the list of things to show if the form | ||||
|                     # errors. | ||||
|                     request.session[SESSION_KEY].append(category.id) | ||||
| 
 | ||||
|                 if request.method == "POST" and not products_form.errors: | ||||
|                     # This is only saved if we pass each form with no errors, | ||||
|                     # and if the form actually has products. | ||||
|                     attendee.guided_categories_complete.add(category) | ||||
|                     if request.method == "POST" and not products_form.errors: | ||||
|                         # This is only saved if we pass each form with no | ||||
|                         # errors, and if the form actually has products. | ||||
|                         attendee.guided_categories_complete.add(category) | ||||
| 
 | ||||
|     if sections and request.method == "POST": | ||||
|         for section in sections: | ||||
|  | @ -427,7 +429,15 @@ def _handle_products(request, category, products, prefix): | |||
|                 ) | ||||
|     handled = False if products_form.errors else True | ||||
| 
 | ||||
|     discounts = discount.available_discounts(request.user, [], products) | ||||
|     # Making this a function to lazily evaluate when it's displayed | ||||
|     # in templates. | ||||
| 
 | ||||
|     discounts = util.lazy( | ||||
|         DiscountController.available_discounts, | ||||
|         request.user, | ||||
|         [], | ||||
|         products, | ||||
|     ) | ||||
| 
 | ||||
|     return products_form, discounts, handled | ||||
| 
 | ||||
|  | @ -435,14 +445,14 @@ def _handle_products(request, category, products, prefix): | |||
| def _set_quantities_from_products_form(products_form, current_cart): | ||||
| 
 | ||||
|     quantities = list(products_form.product_quantities()) | ||||
| 
 | ||||
|     id_to_quantity = dict(i[:2] for i in quantities) | ||||
|     pks = [i[0] for i in quantities] | ||||
|     products = inventory.Product.objects.filter( | ||||
|         id__in=pks, | ||||
|     ).select_related("category") | ||||
| 
 | ||||
|     product_quantities = [ | ||||
|         (products.get(pk=i[0]), i[1]) for i in quantities | ||||
|         (product, id_to_quantity[product.id]) for product in products | ||||
|     ] | ||||
|     field_names = dict( | ||||
|         (i[0][0], i[1][2]) for i in zip(product_quantities, quantities) | ||||
|  |  | |||
|  | @ -1,2 +1,2 @@ | |||
| [flake8] | ||||
| exclude = registrasion/migrations/*, build/*, docs/* | ||||
| exclude = registrasion/migrations/*, build/*, docs/*, dist/* | ||||
|  |  | |||
		Loading…
	
	Add table
		
		Reference in a new issue
	
	 Christopher Neugebauer
						Christopher Neugebauer