Rearchitected condition processing such that multiple conditions are processed by the database, in bulk. Closes #42.
This commit is contained in:
		
							parent
							
								
									05269c93cd
								
							
						
					
					
						commit
						3f1be0e14e
					
				
					 5 changed files with 471 additions and 137 deletions
				
			
		|  | @ -307,6 +307,8 @@ class CartController(object): | ||||||
|             self._append_errors(errors, ve) |             self._append_errors(errors, ve) | ||||||
| 
 | 
 | ||||||
|         # Validate the discounts |         # Validate the discounts | ||||||
|  |         # TODO: refactor in terms of available_discounts | ||||||
|  |         # why aren't we doing that here?! | ||||||
|         discount_items = commerce.DiscountItem.objects.filter(cart=cart) |         discount_items = commerce.DiscountItem.objects.filter(cart=cart) | ||||||
|         seen_discounts = set() |         seen_discounts = set() | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -4,7 +4,12 @@ import operator | ||||||
| from collections import defaultdict | from collections import defaultdict | ||||||
| from collections import namedtuple | from collections import namedtuple | ||||||
| 
 | 
 | ||||||
|  | from django.db.models import Case | ||||||
|  | from django.db.models import Count | ||||||
|  | from django.db.models import F, Q | ||||||
| from django.db.models import Sum | from django.db.models import Sum | ||||||
|  | from django.db.models import Value | ||||||
|  | from django.db.models import When | ||||||
| from django.utils import timezone | from django.utils import timezone | ||||||
| 
 | 
 | ||||||
| from registrasion.models import commerce | from registrasion.models import commerce | ||||||
|  | @ -12,6 +17,7 @@ from registrasion.models import conditions | ||||||
| from registrasion.models import inventory | from registrasion.models import inventory | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
| ConditionAndRemainder = namedtuple( | ConditionAndRemainder = namedtuple( | ||||||
|     "ConditionAndRemainder", |     "ConditionAndRemainder", | ||||||
|     ( |     ( | ||||||
|  | @ -21,16 +27,77 @@ ConditionAndRemainder = namedtuple( | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | _FlagCounter = namedtuple( | ||||||
|  |     "_FlagCounter", | ||||||
|  |     ( | ||||||
|  |         "products", | ||||||
|  |         "categories", | ||||||
|  |     ), | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | _ConditionsCount = namedtuple( | ||||||
|  |     "ConditionsCount", | ||||||
|  |     ( | ||||||
|  |         "dif", | ||||||
|  |         "eit", | ||||||
|  |     ), | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class FlagCounter(_FlagCounter): | ||||||
|  | 
 | ||||||
|  |     @classmethod | ||||||
|  |     def count(cls): | ||||||
|  |         # Get the count of how many conditions should exist per product | ||||||
|  |         flagbases = conditions.FlagBase.objects | ||||||
|  | 
 | ||||||
|  |         types = ( | ||||||
|  |             conditions.FlagBase.ENABLE_IF_TRUE, | ||||||
|  |             conditions.FlagBase.DISABLE_IF_FALSE, | ||||||
|  |         ) | ||||||
|  |         keys = ("eit", "dif") | ||||||
|  |         flags = [ | ||||||
|  |             flagbases.filter( | ||||||
|  |                 condition=condition_type | ||||||
|  |             ).values( | ||||||
|  |                 'products', 'categories' | ||||||
|  |             ).annotate( | ||||||
|  |                 count=Count('id') | ||||||
|  |             ) | ||||||
|  |             for condition_type in types | ||||||
|  |         ] | ||||||
|  | 
 | ||||||
|  |         cats = defaultdict(lambda: defaultdict(int)) | ||||||
|  |         prod = defaultdict(lambda: defaultdict(int)) | ||||||
|  | 
 | ||||||
|  |         for key, flagcounts in zip(keys, flags): | ||||||
|  |             for row in flagcounts: | ||||||
|  |                 if row["products"] is not None: | ||||||
|  |                     prod[row["products"]][key] = row["count"] | ||||||
|  |                 if row["categories"] is not None: | ||||||
|  |                     cats[row["categories"]][key] = row["count"] | ||||||
|  | 
 | ||||||
|  |         return cls(products=prod, categories=cats) | ||||||
|  | 
 | ||||||
|  |     def get(self, product): | ||||||
|  |         p = self.products[product.id] | ||||||
|  |         c = self.categories[product.category.id] | ||||||
|  |         eit = p["eit"] + c["eit"] | ||||||
|  |         dif = p["dif"] + c["dif"] | ||||||
|  |         return _ConditionsCount(dif=dif, eit=eit) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| class ConditionController(object): | class ConditionController(object): | ||||||
|     ''' Base class for testing conditions that activate Flag |     ''' Base class for testing conditions that activate Flag | ||||||
|     or Discount objects. ''' |     or Discount objects. ''' | ||||||
| 
 | 
 | ||||||
|     def __init__(self): |     def __init__(self, condition): | ||||||
|         pass |         self.condition = condition | ||||||
| 
 | 
 | ||||||
|     @staticmethod |     @staticmethod | ||||||
|     def for_condition(condition): |     def _controllers(): | ||||||
|         CONTROLLERS = { |         return { | ||||||
|             conditions.CategoryFlag: CategoryConditionController, |             conditions.CategoryFlag: CategoryConditionController, | ||||||
|             conditions.IncludedProductDiscount: ProductConditionController, |             conditions.IncludedProductDiscount: ProductConditionController, | ||||||
|             conditions.ProductFlag: ProductConditionController, |             conditions.ProductFlag: ProductConditionController, | ||||||
|  | @ -42,8 +109,14 @@ class ConditionController(object): | ||||||
|             conditions.VoucherFlag: VoucherConditionController, |             conditions.VoucherFlag: VoucherConditionController, | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|  |     @staticmethod | ||||||
|  |     def for_type(cls): | ||||||
|  |         return ConditionController._controllers()[cls] | ||||||
|  | 
 | ||||||
|  |     @staticmethod | ||||||
|  |     def for_condition(condition): | ||||||
|         try: |         try: | ||||||
|             return CONTROLLERS[type(condition)](condition) |             return ConditionController.for_type(type(condition))(condition) | ||||||
|         except KeyError: |         except KeyError: | ||||||
|             return ConditionController() |             return ConditionController() | ||||||
| 
 | 
 | ||||||
|  | @ -91,20 +164,9 @@ class ConditionController(object): | ||||||
|             products = set(products) |             products = set(products) | ||||||
|             quantities = {} |             quantities = {} | ||||||
| 
 | 
 | ||||||
|         # Get the conditions covered by the products themselves |  | ||||||
|         prods = ( |  | ||||||
|             product.flagbase_set.select_subclasses() |  | ||||||
|             for product in products |  | ||||||
|         ) |  | ||||||
|         # Get the conditions covered by their categories |  | ||||||
|         cats = ( |  | ||||||
|             category.flagbase_set.select_subclasses() |  | ||||||
|             for category in set(product.category for product in products) |  | ||||||
|         ) |  | ||||||
| 
 |  | ||||||
|         if products: |         if products: | ||||||
|             # Simplify the query. |             # Simplify the query. | ||||||
|             all_conditions = reduce(operator.or_, itertools.chain(prods, cats)) |             all_conditions = cls._filtered_flags(user, products) | ||||||
|         else: |         else: | ||||||
|             all_conditions = [] |             all_conditions = [] | ||||||
| 
 | 
 | ||||||
|  | @ -114,11 +176,15 @@ class ConditionController(object): | ||||||
|         do_enable = defaultdict(lambda: False) |         do_enable = defaultdict(lambda: False) | ||||||
|         # (if either sort of condition is present) |         # (if either sort of condition is present) | ||||||
| 
 | 
 | ||||||
|  |         # Count the number of conditions for a product | ||||||
|  |         dif_count = defaultdict(int) | ||||||
|  |         eit_count = defaultdict(int) | ||||||
|  | 
 | ||||||
|         messages = {} |         messages = {} | ||||||
| 
 | 
 | ||||||
|         for condition in all_conditions: |         for condition in all_conditions: | ||||||
|             cond = cls.for_condition(condition) |             cond = cls.for_condition(condition) | ||||||
|             remainder = cond.user_quantity_remaining(user) |             remainder = cond.user_quantity_remaining(user, filtered=True) | ||||||
| 
 | 
 | ||||||
|             # Get all products covered by this condition, and the products |             # Get all products covered by this condition, and the products | ||||||
|             # from the categories covered by this condition |             # from the categories covered by this condition | ||||||
|  | @ -149,14 +215,41 @@ class ConditionController(object): | ||||||
|             for product in all_products: |             for product in all_products: | ||||||
|                 if condition.is_disable_if_false: |                 if condition.is_disable_if_false: | ||||||
|                     do_not_disable[product] &= met |                     do_not_disable[product] &= met | ||||||
|  |                     dif_count[product] += 1 | ||||||
|                 else: |                 else: | ||||||
|                     do_enable[product] |= met |                     do_enable[product] |= met | ||||||
|  |                     eit_count[product] += 1 | ||||||
| 
 | 
 | ||||||
|                 if not met and product not in messages: |                 if not met and product not in messages: | ||||||
|                     messages[product] = message |                     messages[product] = message | ||||||
| 
 | 
 | ||||||
|  |         total_flags = FlagCounter.count() | ||||||
|  | 
 | ||||||
|         valid = {} |         valid = {} | ||||||
|  | 
 | ||||||
|  |         # the problem is that now, not every condition falls into | ||||||
|  |         # do_not_disable or do_enable ''' | ||||||
|  |         # You should look into this, chris :) | ||||||
|  | 
 | ||||||
|  |         for product in products: | ||||||
|  |             if quantities: | ||||||
|  |                 if quantities[product] == 0: | ||||||
|  |                     continue | ||||||
|  | 
 | ||||||
|  |             f = total_flags.get(product) | ||||||
|  |             if f.dif > 0 and f.dif != dif_count[product]: | ||||||
|  |                 do_not_disable[product] = False | ||||||
|  |                 if product not in messages: | ||||||
|  |                     messages[product] = "Some disable-if-false " \ | ||||||
|  |                                         "conditions were not met" | ||||||
|  |             if f.eit > 0 and product not in do_enable: | ||||||
|  |                 do_enable[product] = False | ||||||
|  |                 if product not in messages: | ||||||
|  |                     messages[product] = "Some enable-if-true " \ | ||||||
|  |                                         "conditions were not met" | ||||||
|  | 
 | ||||||
|         for product in itertools.chain(do_not_disable, do_enable): |         for product in itertools.chain(do_not_disable, do_enable): | ||||||
|  |             f = total_flags.get(product) | ||||||
|             if product in do_enable: |             if product in do_enable: | ||||||
|                 # If there's an enable-if-true, we need need of those met too. |                 # If there's an enable-if-true, we need need of those met too. | ||||||
|                 # (do_not_disable will default to true otherwise) |                 # (do_not_disable will default to true otherwise) | ||||||
|  | @ -172,7 +265,71 @@ class ConditionController(object): | ||||||
| 
 | 
 | ||||||
|         return error_fields |         return error_fields | ||||||
| 
 | 
 | ||||||
|     def user_quantity_remaining(self, user): |     @classmethod | ||||||
|  |     def _filtered_flags(cls, user, products): | ||||||
|  |         ''' | ||||||
|  | 
 | ||||||
|  |         Returns: | ||||||
|  |             Sequence[flagbase]: All flags that passed the filter function. | ||||||
|  | 
 | ||||||
|  |         ''' | ||||||
|  | 
 | ||||||
|  |         types = list(ConditionController._controllers()) | ||||||
|  |         flagtypes = [i for i in types if issubclass(i, conditions.FlagBase)] | ||||||
|  | 
 | ||||||
|  |         # Get all flags for the products and categories. | ||||||
|  |         prods = ( | ||||||
|  |             product.flagbase_set.all() | ||||||
|  |             for product in products | ||||||
|  |         ) | ||||||
|  |         cats = ( | ||||||
|  |             category.flagbase_set.all() | ||||||
|  |             for category in set(product.category for product in products) | ||||||
|  |         ) | ||||||
|  |         all_flags = reduce(operator.or_, itertools.chain(prods, cats)) | ||||||
|  | 
 | ||||||
|  |         all_subsets = [] | ||||||
|  | 
 | ||||||
|  |         for flagtype in flagtypes: | ||||||
|  |             flags = flagtype.objects.filter(id__in=all_flags) | ||||||
|  |             ctrl = ConditionController.for_type(flagtype) | ||||||
|  |             flags = ctrl.pre_filter(flags, user) | ||||||
|  |             all_subsets.append(flags) | ||||||
|  | 
 | ||||||
|  |         return itertools.chain(*all_subsets) | ||||||
|  | 
 | ||||||
|  |     @classmethod | ||||||
|  |     def pre_filter(cls, queryset, user): | ||||||
|  |         ''' Returns only the flag conditions that might be available for this | ||||||
|  |         user. It should hopefully reduce the number of queries that need to be | ||||||
|  |         executed to determine if a flag is met. | ||||||
|  | 
 | ||||||
|  |         If this filtration implements the same query as is_met, then you should | ||||||
|  |         be able to implement ``is_met()`` in terms of this. | ||||||
|  | 
 | ||||||
|  |         Arguments: | ||||||
|  | 
 | ||||||
|  |             queryset (Queryset[c]): The canditate conditions. | ||||||
|  | 
 | ||||||
|  |             user (User): The user for whom we're testing these conditions. | ||||||
|  | 
 | ||||||
|  |         Returns: | ||||||
|  |             Queryset[c]: A subset of the conditions that pass the pre-filter | ||||||
|  |                 test for this user. | ||||||
|  | 
 | ||||||
|  |         ''' | ||||||
|  | 
 | ||||||
|  |         # Default implementation does NOTHING. | ||||||
|  |         return queryset | ||||||
|  | 
 | ||||||
|  |     def passes_filter(self, user): | ||||||
|  |         ''' Returns true if the condition passes the filter ''' | ||||||
|  | 
 | ||||||
|  |         cls = type(self.condition) | ||||||
|  |         qs = cls.objects.filter(pk=self.condition.id) | ||||||
|  |         return self.condition in self.pre_filter(qs, user) | ||||||
|  | 
 | ||||||
|  |     def user_quantity_remaining(self, user, filtered=False): | ||||||
|         ''' Returns the number of items covered by this flag condition the |         ''' Returns the number of items covered by this flag condition the | ||||||
|         user can add to the current cart. This default implementation returns |         user can add to the current cart. This default implementation returns | ||||||
|         a big number if is_met() is true, otherwise 0. |         a big number if is_met() is true, otherwise 0. | ||||||
|  | @ -180,26 +337,37 @@ class ConditionController(object): | ||||||
|         Either this method, or is_met() must be overridden in subclasses. |         Either this method, or is_met() must be overridden in subclasses. | ||||||
|         ''' |         ''' | ||||||
| 
 | 
 | ||||||
|         return 99999999 if self.is_met(user) else 0 |         return _BIG_QUANTITY if self.is_met(user, filtered) else 0 | ||||||
| 
 | 
 | ||||||
|     def is_met(self, user): |     def is_met(self, user, filtered=False): | ||||||
|         ''' Returns True if this flag condition is met, otherwise returns |         ''' Returns True if this flag condition is met, otherwise returns | ||||||
|         False. |         False. | ||||||
| 
 | 
 | ||||||
|         Either this method, or user_quantity_remaining() must be overridden |         Either this method, or user_quantity_remaining() must be overridden | ||||||
|         in subclasses. |         in subclasses. | ||||||
|  | 
 | ||||||
|  |         Arguments: | ||||||
|  | 
 | ||||||
|  |             user (User): The user for whom this test must be met. | ||||||
|  | 
 | ||||||
|  |             filter (bool): If true, this condition was part of a queryset | ||||||
|  |                 returned by pre_filter() for this user. | ||||||
|  | 
 | ||||||
|         ''' |         ''' | ||||||
|         return self.user_quantity_remaining(user) > 0 |         return self.user_quantity_remaining(user, filtered) > 0 | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class CategoryConditionController(ConditionController): | class IsMetByFilter(object): | ||||||
| 
 | 
 | ||||||
|     def __init__(self, condition): |     def is_met(self, user, filtered=False): | ||||||
|         self.condition = condition |         ''' Returns True if this flag condition is met, otherwise returns | ||||||
|  |         False. It determines if the condition is met by calling pre_filter | ||||||
|  |         with a queryset containing only self.condition. ''' | ||||||
| 
 | 
 | ||||||
|     def is_met(self, user): |         if filtered: | ||||||
|         ''' returns True if the user has a product from a category that invokes |             return True  # Why query again? | ||||||
|         this condition in one of their carts ''' | 
 | ||||||
|  |         return self.passes_filter(user) | ||||||
| 
 | 
 | ||||||
|         carts = commerce.Cart.objects.filter(user=user) |         carts = commerce.Cart.objects.filter(user=user) | ||||||
|         carts = carts.exclude(status=commerce.Cart.STATUS_RELEASED) |         carts = carts.exclude(status=commerce.Cart.STATUS_RELEASED) | ||||||
|  | @ -212,112 +380,176 @@ class CategoryConditionController(ConditionController): | ||||||
|         ).count() |         ).count() | ||||||
|         return products_count > 0 |         return products_count > 0 | ||||||
| 
 | 
 | ||||||
|  | class RemainderSetByFilter(object): | ||||||
| 
 | 
 | ||||||
| class ProductConditionController(ConditionController): |     def user_quantity_remaining(self, user, filtered=True): | ||||||
|  |         ''' returns 0 if the date range is violated, otherwise, it will return | ||||||
|  |         the quantity remaining under the stock limit. | ||||||
|  | 
 | ||||||
|  |         The filter for this condition must add an annotation called "remainder" | ||||||
|  |         in order for this to work. | ||||||
|  |         ''' | ||||||
|  | 
 | ||||||
|  |         if filtered: | ||||||
|  |             if hasattr(self.condition, "remainder"): | ||||||
|  |                 return self.condition.remainder | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  |         # Mark self.condition with a remainder | ||||||
|  |         qs = type(self.condition).objects.filter(pk=self.condition.id) | ||||||
|  |         qs = self.pre_filter(qs, user) | ||||||
|  | 
 | ||||||
|  |         if len(qs) > 0: | ||||||
|  |             return qs[0].remainder | ||||||
|  |         else: | ||||||
|  |             return 0 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class CategoryConditionController(IsMetByFilter, ConditionController): | ||||||
|  | 
 | ||||||
|  |     @classmethod | ||||||
|  |     def pre_filter(self, queryset, user): | ||||||
|  |         ''' Returns all of the items from queryset where the user has a | ||||||
|  |         product from a category invoking that item's condition in one of their | ||||||
|  |         carts. ''' | ||||||
|  | 
 | ||||||
|  |         items = commerce.ProductItem.objects.filter(cart__user=user) | ||||||
|  |         items = items.exclude(cart__status=commerce.Cart.STATUS_RELEASED) | ||||||
|  |         items = items.select_related("product", "product__category") | ||||||
|  |         categories = [item.product.category for item in items] | ||||||
|  | 
 | ||||||
|  |         return queryset.filter(enabling_category__in=categories) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class ProductConditionController(IsMetByFilter, ConditionController): | ||||||
|     ''' Condition tests for ProductFlag and |     ''' Condition tests for ProductFlag and | ||||||
|     IncludedProductDiscount. ''' |     IncludedProductDiscount. ''' | ||||||
| 
 | 
 | ||||||
|     def __init__(self, condition): |     @classmethod | ||||||
|         self.condition = condition |     def pre_filter(self, queryset, user): | ||||||
|  |         ''' Returns all of the items from queryset where the user has a | ||||||
|  |         product invoking that item's condition in one of their carts. ''' | ||||||
| 
 | 
 | ||||||
|     def is_met(self, user): |         items = commerce.ProductItem.objects.filter(cart__user=user) | ||||||
|         ''' returns True if the user has a product that invokes this |         items = items.exclude(cart__status=commerce.Cart.STATUS_RELEASED) | ||||||
|         condition in one of their carts ''' |         items = items.select_related("product", "product__category") | ||||||
|  |         products = [item.product for item in items] | ||||||
| 
 | 
 | ||||||
|         carts = commerce.Cart.objects.filter(user=user) |         return queryset.filter(enabling_products__in=products) | ||||||
|         carts = carts.exclude(status=commerce.Cart.STATUS_RELEASED) |  | ||||||
|         products_count = commerce.ProductItem.objects.filter( |  | ||||||
|             cart__in=carts, |  | ||||||
|             product__in=self.condition.enabling_products.all(), |  | ||||||
|         ).count() |  | ||||||
|         return products_count > 0 |  | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class TimeOrStockLimitConditionController(ConditionController): | class TimeOrStockLimitConditionController( | ||||||
|  |         RemainderSetByFilter, | ||||||
|  |         ConditionController, | ||||||
|  |     ): | ||||||
|     ''' Common condition tests for TimeOrStockLimit Flag and |     ''' Common condition tests for TimeOrStockLimit Flag and | ||||||
|     Discount.''' |     Discount.''' | ||||||
| 
 | 
 | ||||||
|     def __init__(self, ceiling): |     @classmethod | ||||||
|         self.ceiling = ceiling |     def pre_filter(self, queryset, user): | ||||||
|  |         ''' Returns all of the items from queryset where the date falls into | ||||||
|  |         any specified range, but not yet where the stock limit is not yet | ||||||
|  |         reached.''' | ||||||
| 
 | 
 | ||||||
|     def user_quantity_remaining(self, user): |  | ||||||
|         ''' returns 0 if the date range is violated, otherwise, it will return |  | ||||||
|         the quantity remaining under the stock limit. ''' |  | ||||||
| 
 |  | ||||||
|         # Test date range |  | ||||||
|         if not self._test_date_range(): |  | ||||||
|             return 0 |  | ||||||
| 
 |  | ||||||
|         return self._get_remaining_stock(user) |  | ||||||
| 
 |  | ||||||
|     def _test_date_range(self): |  | ||||||
|         now = timezone.now() |         now = timezone.now() | ||||||
| 
 | 
 | ||||||
|         if self.ceiling.start_time is not None: |         # Keep items with no start time, or start time not yet met. | ||||||
|             if now < self.ceiling.start_time: |         queryset = queryset.filter(Q(start_time=None) | Q(start_time__lte=now)) | ||||||
|                 return False |         queryset = queryset.filter(Q(end_time=None) | Q(end_time__gte=now)) | ||||||
| 
 | 
 | ||||||
|         if self.ceiling.end_time is not None: |         # Filter out items that have been reserved beyond the limits | ||||||
|             if now > self.ceiling.end_time: |         quantity_or_zero = self._calculate_quantities(user) | ||||||
|                 return False |  | ||||||
| 
 | 
 | ||||||
|         return True |         remainder = Case( | ||||||
|  |             When(limit=None, then=Value(_BIG_QUANTITY)), | ||||||
|  |             default=F("limit") - Sum(quantity_or_zero), | ||||||
|  |         ) | ||||||
| 
 | 
 | ||||||
|     def _get_remaining_stock(self, user): |         queryset = queryset.annotate(remainder=remainder) | ||||||
|         ''' Returns the stock that remains under this ceiling, excluding the |         queryset = queryset.filter(remainder__gt=0) | ||||||
|         user's current cart. ''' |  | ||||||
| 
 | 
 | ||||||
|         if self.ceiling.limit is None: |         return queryset | ||||||
|             return 99999999 |  | ||||||
| 
 | 
 | ||||||
|         # We care about all reserved carts, but not the user's current cart |     @classmethod | ||||||
|  |     def _relevant_carts(cls, user): | ||||||
|         reserved_carts = commerce.Cart.reserved_carts() |         reserved_carts = commerce.Cart.reserved_carts() | ||||||
|         reserved_carts = reserved_carts.exclude( |         reserved_carts = reserved_carts.exclude( | ||||||
|             user=user, |             user=user, | ||||||
|             status=commerce.Cart.STATUS_ACTIVE, |             status=commerce.Cart.STATUS_ACTIVE, | ||||||
|         ) |         ) | ||||||
| 
 |         return reserved_carts | ||||||
|         items = self._items() |  | ||||||
|         items = items.filter(cart__in=reserved_carts) |  | ||||||
|         count = items.aggregate(Sum("quantity"))["quantity__sum"] or 0 |  | ||||||
| 
 |  | ||||||
|         return self.ceiling.limit - count |  | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class TimeOrStockLimitFlagController( | class TimeOrStockLimitFlagController( | ||||||
|         TimeOrStockLimitConditionController): |         TimeOrStockLimitConditionController): | ||||||
| 
 | 
 | ||||||
|     def _items(self): |     @classmethod | ||||||
|         category_products = inventory.Product.objects.filter( |     def _calculate_quantities(cls, user): | ||||||
|             category__in=self.ceiling.categories.all(), |         reserved_carts = cls._relevant_carts(user) | ||||||
|         ) |  | ||||||
|         products = self.ceiling.products.all() | category_products |  | ||||||
| 
 | 
 | ||||||
|         product_items = commerce.ProductItem.objects.filter( |         # Calculate category lines | ||||||
|             product__in=products.all(), |         cat_items = F('categories__product__productitem__product__category') | ||||||
|  |         reserved_category_products = ( | ||||||
|  |             Q(categories=cat_items) & | ||||||
|  |             Q(categories__product__productitem__cart__in=reserved_carts) | ||||||
|         ) |         ) | ||||||
|         return product_items | 
 | ||||||
|  |         # Calculate product lines | ||||||
|  |         reserved_products = ( | ||||||
|  |             Q(products=F('products__productitem__product')) & | ||||||
|  |             Q(products__productitem__cart__in=reserved_carts) | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         category_quantity_in_reserved_carts = When( | ||||||
|  |             reserved_category_products, | ||||||
|  |             then="categories__product__productitem__quantity", | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         product_quantity_in_reserved_carts = When( | ||||||
|  |             reserved_products, | ||||||
|  |             then="products__productitem__quantity", | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         quantity_or_zero = Case( | ||||||
|  |             category_quantity_in_reserved_carts, | ||||||
|  |             product_quantity_in_reserved_carts, | ||||||
|  |             default=Value(0), | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         return quantity_or_zero | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class TimeOrStockLimitDiscountController(TimeOrStockLimitConditionController): | class TimeOrStockLimitDiscountController(TimeOrStockLimitConditionController): | ||||||
| 
 | 
 | ||||||
|     def _items(self): |     @classmethod | ||||||
|         discount_items = commerce.DiscountItem.objects.filter( |     def _calculate_quantities(cls, user): | ||||||
|             discount=self.ceiling, |         reserved_carts = cls._relevant_carts(user) | ||||||
|  | 
 | ||||||
|  |         quantity_in_reserved_carts = When( | ||||||
|  |             discountitem__cart__in=reserved_carts, | ||||||
|  |             then="discountitem__quantity" | ||||||
|         ) |         ) | ||||||
|         return discount_items | 
 | ||||||
|  |         quantity_or_zero = Case( | ||||||
|  |             quantity_in_reserved_carts, | ||||||
|  |             default=Value(0) | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         return quantity_or_zero | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class VoucherConditionController(ConditionController): | class VoucherConditionController(IsMetByFilter, ConditionController): | ||||||
|     ''' Condition test for VoucherFlag and VoucherDiscount.''' |     ''' Condition test for VoucherFlag and VoucherDiscount.''' | ||||||
| 
 | 
 | ||||||
|     def __init__(self, condition): |     @classmethod | ||||||
|         self.condition = condition |     def pre_filter(self, queryset, user): | ||||||
|  |         ''' Returns all of the items from queryset where the user has entered | ||||||
|  |         a voucher that invokes that item's condition in one of their carts. ''' | ||||||
| 
 | 
 | ||||||
|     def is_met(self, user): |         carts = commerce.Cart.objects.filter( | ||||||
|         ''' returns True if the user has the given voucher attached. ''' |  | ||||||
|         carts_count = commerce.Cart.objects.filter( |  | ||||||
|             user=user, |             user=user, | ||||||
|             vouchers=self.condition.voucher, |         ) | ||||||
|         ).count() |         vouchers = [cart.vouchers.all() for cart in carts] | ||||||
|         return carts_count > 0 | 
 | ||||||
|  |         return queryset.filter(voucher__in=itertools.chain(*vouchers)) | ||||||
|  |  | ||||||
|  | @ -4,7 +4,11 @@ from conditions import ConditionController | ||||||
| from registrasion.models import commerce | from registrasion.models import commerce | ||||||
| from registrasion.models import conditions | from registrasion.models import conditions | ||||||
| 
 | 
 | ||||||
|  | from django.db.models import Case | ||||||
|  | from django.db.models import Q | ||||||
| from django.db.models import Sum | from django.db.models import Sum | ||||||
|  | from django.db.models import Value | ||||||
|  | from django.db.models import When | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class DiscountAndQuantity(object): | class DiscountAndQuantity(object): | ||||||
|  | @ -43,6 +47,62 @@ def available_discounts(user, categories, products): | ||||||
|     and products. The discounts also list the available quantity for this user, |     and products. The discounts also list the available quantity for this user, | ||||||
|     not including products that are pending purchase. ''' |     not including products that are pending purchase. ''' | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  |     filtered_clauses = _filtered_discounts(user, categories, products) | ||||||
|  | 
 | ||||||
|  |     discounts = [] | ||||||
|  | 
 | ||||||
|  |     # Markers so that we don't need to evaluate given conditions more than once | ||||||
|  |     accepted_discounts = set() | ||||||
|  |     failed_discounts = set() | ||||||
|  | 
 | ||||||
|  |     for clause in filtered_clauses: | ||||||
|  |         discount = clause.discount | ||||||
|  |         cond = ConditionController.for_condition(discount) | ||||||
|  | 
 | ||||||
|  |         past_use_count = discount.past_use_count | ||||||
|  | 
 | ||||||
|  |         # TODO: add test case -- | ||||||
|  |         # discount covers 2x prod_1 and 1x prod_2 | ||||||
|  |         # add 1x prod_2 | ||||||
|  |         # add 1x prod_1 | ||||||
|  |         # checkout | ||||||
|  |         # discount should be available for prod_1 | ||||||
|  | 
 | ||||||
|  |         if past_use_count >= clause.quantity: | ||||||
|  |             # This clause has exceeded its use count | ||||||
|  |             pass | ||||||
|  |         elif discount not in failed_discounts: | ||||||
|  |             # This clause is still available | ||||||
|  |             is_accepted = discount in accepted_discounts | ||||||
|  |             if is_accepted or cond.is_met(user, filtered=True): | ||||||
|  |                 # This clause is valid for this user | ||||||
|  |                 discounts.append(DiscountAndQuantity( | ||||||
|  |                     discount=discount, | ||||||
|  |                     clause=clause, | ||||||
|  |                     quantity=clause.quantity - past_use_count, | ||||||
|  |                 )) | ||||||
|  |                 accepted_discounts.add(discount) | ||||||
|  |             else: | ||||||
|  |                 # This clause is not valid for this user | ||||||
|  |                 failed_discounts.add(discount) | ||||||
|  |     return discounts | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def _filtered_discounts(user, categories, products): | ||||||
|  |     ''' | ||||||
|  | 
 | ||||||
|  |     Returns: | ||||||
|  |         Sequence[discountbase]: All discounts that passed the filter function. | ||||||
|  | 
 | ||||||
|  |     ''' | ||||||
|  | 
 | ||||||
|  |     types = list(ConditionController._controllers()) | ||||||
|  |     discounttypes = [ | ||||||
|  |         i for i in types if issubclass(i, conditions.DiscountBase) | ||||||
|  |     ] | ||||||
|  | 
 | ||||||
|     # discounts that match provided categories |     # discounts that match provided categories | ||||||
|     category_discounts = conditions.DiscountForCategory.objects.filter( |     category_discounts = conditions.DiscountForCategory.objects.filter( | ||||||
|         category__in=categories |         category__in=categories | ||||||
|  | @ -67,51 +127,56 @@ def available_discounts(user, categories, products): | ||||||
|         "category", |         "category", | ||||||
|     ) |     ) | ||||||
| 
 | 
 | ||||||
|  |     valid_discounts = conditions.DiscountBase.objects.filter( | ||||||
|  |         Q(discountforproduct__in=product_discounts) | | ||||||
|  |         Q(discountforcategory__in=all_category_discounts) | ||||||
|  |     ) | ||||||
|  | 
 | ||||||
|  |     all_subsets = [] | ||||||
|  | 
 | ||||||
|  |     for discounttype in discounttypes: | ||||||
|  |         discounts = discounttype.objects.filter(id__in=valid_discounts) | ||||||
|  |         ctrl = ConditionController.for_type(discounttype) | ||||||
|  |         discounts = ctrl.pre_filter(discounts, user) | ||||||
|  |         discounts = _annotate_with_past_uses(discounts, user) | ||||||
|  |         all_subsets.append(discounts) | ||||||
|  | 
 | ||||||
|  |     filtered_discounts = list(itertools.chain(*all_subsets)) | ||||||
|  | 
 | ||||||
|  |     # Map from discount key to itself (contains annotations added by filter) | ||||||
|  |     from_filter = dict((i.id, i) for i in filtered_discounts) | ||||||
|  | 
 | ||||||
|     # The set of all potential discounts |     # The set of all potential discounts | ||||||
|     potential_discounts = set(itertools.chain( |     discount_clauses = set(itertools.chain( | ||||||
|         product_discounts, |         product_discounts.filter(discount__in=filtered_discounts), | ||||||
|         all_category_discounts, |         all_category_discounts.filter(discount__in=filtered_discounts), | ||||||
|     )) |     )) | ||||||
| 
 | 
 | ||||||
|     discounts = [] |     # Replace discounts with the filtered ones | ||||||
|  |     # These are the correct subclasses (saves query later on), and have | ||||||
|  |     # correct annotations from filters if necessary. | ||||||
|  |     for clause in discount_clauses: | ||||||
|  |         clause.discount = from_filter[clause.discount.id] | ||||||
| 
 | 
 | ||||||
|     # Markers so that we don't need to evaluate given conditions more than once |     return discount_clauses | ||||||
|     accepted_discounts = set() |  | ||||||
|     failed_discounts = set() |  | ||||||
| 
 | 
 | ||||||
|     for discount in potential_discounts: |  | ||||||
|         real_discount = conditions.DiscountBase.objects.get_subclass( |  | ||||||
|             pk=discount.discount.pk, |  | ||||||
|         ) |  | ||||||
|         cond = ConditionController.for_condition(real_discount) |  | ||||||
| 
 | 
 | ||||||
|         # Count the past uses of the given discount item. | def _annotate_with_past_uses(queryset, user): | ||||||
|         # If this user has exceeded the limit for the clause, this clause |     ''' Annotates the queryset with a usage count for that discount by the | ||||||
|         # is not available any more. |     given user. ''' | ||||||
|         past_uses = commerce.DiscountItem.objects.filter( |  | ||||||
|             cart__user=user, |  | ||||||
|             cart__status=commerce.Cart.STATUS_PAID,  # Only past carts count |  | ||||||
|             discount=real_discount, |  | ||||||
|         ) |  | ||||||
|         agg = past_uses.aggregate(Sum("quantity")) |  | ||||||
|         past_use_count = agg["quantity__sum"] |  | ||||||
|         if past_use_count is None: |  | ||||||
|             past_use_count = 0 |  | ||||||
| 
 | 
 | ||||||
|         if past_use_count >= discount.quantity: |     past_use_quantity = When( | ||||||
|             # This clause has exceeded its use count |         ( | ||||||
|             pass |             Q(discountitem__cart__user=user) & | ||||||
|         elif real_discount not in failed_discounts: |             Q(discountitem__cart__status=commerce.Cart.STATUS_PAID) | ||||||
|             # This clause is still available |         ), | ||||||
|             if real_discount in accepted_discounts or cond.is_met(user): |         then="discountitem__quantity", | ||||||
|                 # This clause is valid for this user |     ) | ||||||
|                 discounts.append(DiscountAndQuantity( | 
 | ||||||
|                     discount=real_discount, |     past_use_quantity_or_zero = Case( | ||||||
|                     clause=discount, |         past_use_quantity, | ||||||
|                     quantity=discount.quantity - past_use_count, |         default=Value(0), | ||||||
|                 )) |     ) | ||||||
|                 accepted_discounts.add(real_discount) | 
 | ||||||
|             else: |     queryset = queryset.annotate(past_use_count=Sum(past_use_quantity_or_zero)) | ||||||
|                 # This clause is not valid for this user |     return queryset | ||||||
|                 failed_discounts.add(real_discount) |  | ||||||
|     return discounts |  | ||||||
|  |  | ||||||
|  | @ -6,6 +6,8 @@ from django.core.exceptions import ValidationError | ||||||
| from controller_helpers import TestingCartController | from controller_helpers import TestingCartController | ||||||
| from test_cart import RegistrationCartTestCase | from test_cart import RegistrationCartTestCase | ||||||
| 
 | 
 | ||||||
|  | from registrasion.controllers.discount import available_discounts | ||||||
|  | from registrasion.controllers.product import ProductController | ||||||
| from registrasion.models import commerce | from registrasion.models import commerce | ||||||
| from registrasion.models import conditions | from registrasion.models import conditions | ||||||
| 
 | 
 | ||||||
|  | @ -135,6 +137,39 @@ class CeilingsTestCases(RegistrationCartTestCase): | ||||||
|         with self.assertRaises(ValidationError): |         with self.assertRaises(ValidationError): | ||||||
|             first_cart.validate_cart() |             first_cart.validate_cart() | ||||||
| 
 | 
 | ||||||
|  |     def test_discount_ceiling_aggregates_products(self): | ||||||
|  |         # Create two carts, add 1xprod_1 to each. Ceiling should disappear | ||||||
|  |         # after second. | ||||||
|  |         self.make_discount_ceiling( | ||||||
|  |             "Multi-product limit discount ceiling", | ||||||
|  |             limit=2, | ||||||
|  |         ) | ||||||
|  |         for i in xrange(2): | ||||||
|  |             cart = TestingCartController.for_user(self.USER_1) | ||||||
|  |             cart.add_to_cart(self.PROD_1, 1) | ||||||
|  |             cart.next_cart() | ||||||
|  | 
 | ||||||
|  |         discounts = available_discounts(self.USER_1, [], [self.PROD_1]) | ||||||
|  | 
 | ||||||
|  |         self.assertEqual(0, len(discounts)) | ||||||
|  | 
 | ||||||
|  |     def test_flag_ceiling_aggregates_products(self): | ||||||
|  |         # Create two carts, add 1xprod_1 to each. Ceiling should disappear | ||||||
|  |         # after second. | ||||||
|  |         self.make_ceiling("Multi-product limit ceiling", limit=2) | ||||||
|  | 
 | ||||||
|  |         for i in xrange(2): | ||||||
|  |             cart = TestingCartController.for_user(self.USER_1) | ||||||
|  |             cart.add_to_cart(self.PROD_1, 1) | ||||||
|  |             cart.next_cart() | ||||||
|  | 
 | ||||||
|  |         products = ProductController.available_products( | ||||||
|  |             self.USER_1, | ||||||
|  |             products=[self.PROD_1], | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         self.assertEqual(0, len(products)) | ||||||
|  | 
 | ||||||
|     def test_items_released_from_ceiling_by_refund(self): |     def test_items_released_from_ceiling_by_refund(self): | ||||||
|         self.make_ceiling("Limit ceiling", limit=1) |         self.make_ceiling("Limit ceiling", limit=1) | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -1,2 +1,2 @@ | ||||||
| [flake8] | [flake8] | ||||||
| exclude = registrasion/migrations/*, build/*, docs/* | exclude = registrasion/migrations/*, build/*, docs/*, dist/* | ||||||
|  |  | ||||||
		Loading…
	
	Add table
		
		Reference in a new issue
	
	 Christopher Neugebauer
						Christopher Neugebauer