Rearchitected condition processing such that multiple conditions are processed by the database, in bulk. Closes #42.
This commit is contained in:
		
							parent
							
								
									05269c93cd
								
							
						
					
					
						commit
						3f1be0e14e
					
				
					 5 changed files with 471 additions and 137 deletions
				
			
		|  | @ -307,6 +307,8 @@ class CartController(object): | |||
|             self._append_errors(errors, ve) | ||||
| 
 | ||||
|         # Validate the discounts | ||||
|         # TODO: refactor in terms of available_discounts | ||||
|         # why aren't we doing that here?! | ||||
|         discount_items = commerce.DiscountItem.objects.filter(cart=cart) | ||||
|         seen_discounts = set() | ||||
| 
 | ||||
|  |  | |||
|  | @ -4,7 +4,12 @@ import operator | |||
| from collections import defaultdict | ||||
| from collections import namedtuple | ||||
| 
 | ||||
| from django.db.models import Case | ||||
| from django.db.models import Count | ||||
| from django.db.models import F, Q | ||||
| from django.db.models import Sum | ||||
| from django.db.models import Value | ||||
| from django.db.models import When | ||||
| from django.utils import timezone | ||||
| 
 | ||||
| from registrasion.models import commerce | ||||
|  | @ -12,6 +17,7 @@ from registrasion.models import conditions | |||
| from registrasion.models import inventory | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| ConditionAndRemainder = namedtuple( | ||||
|     "ConditionAndRemainder", | ||||
|     ( | ||||
|  | @ -21,16 +27,77 @@ ConditionAndRemainder = namedtuple( | |||
| ) | ||||
| 
 | ||||
| 
 | ||||
| _FlagCounter = namedtuple( | ||||
|     "_FlagCounter", | ||||
|     ( | ||||
|         "products", | ||||
|         "categories", | ||||
|     ), | ||||
| ) | ||||
| 
 | ||||
| 
 | ||||
| _ConditionsCount = namedtuple( | ||||
|     "ConditionsCount", | ||||
|     ( | ||||
|         "dif", | ||||
|         "eit", | ||||
|     ), | ||||
| ) | ||||
| 
 | ||||
| 
 | ||||
| class FlagCounter(_FlagCounter): | ||||
| 
 | ||||
|     @classmethod | ||||
|     def count(cls): | ||||
|         # Get the count of how many conditions should exist per product | ||||
|         flagbases = conditions.FlagBase.objects | ||||
| 
 | ||||
|         types = ( | ||||
|             conditions.FlagBase.ENABLE_IF_TRUE, | ||||
|             conditions.FlagBase.DISABLE_IF_FALSE, | ||||
|         ) | ||||
|         keys = ("eit", "dif") | ||||
|         flags = [ | ||||
|             flagbases.filter( | ||||
|                 condition=condition_type | ||||
|             ).values( | ||||
|                 'products', 'categories' | ||||
|             ).annotate( | ||||
|                 count=Count('id') | ||||
|             ) | ||||
|             for condition_type in types | ||||
|         ] | ||||
| 
 | ||||
|         cats = defaultdict(lambda: defaultdict(int)) | ||||
|         prod = defaultdict(lambda: defaultdict(int)) | ||||
| 
 | ||||
|         for key, flagcounts in zip(keys, flags): | ||||
|             for row in flagcounts: | ||||
|                 if row["products"] is not None: | ||||
|                     prod[row["products"]][key] = row["count"] | ||||
|                 if row["categories"] is not None: | ||||
|                     cats[row["categories"]][key] = row["count"] | ||||
| 
 | ||||
|         return cls(products=prod, categories=cats) | ||||
| 
 | ||||
|     def get(self, product): | ||||
|         p = self.products[product.id] | ||||
|         c = self.categories[product.category.id] | ||||
|         eit = p["eit"] + c["eit"] | ||||
|         dif = p["dif"] + c["dif"] | ||||
|         return _ConditionsCount(dif=dif, eit=eit) | ||||
| 
 | ||||
| 
 | ||||
| class ConditionController(object): | ||||
|     ''' Base class for testing conditions that activate Flag | ||||
|     or Discount objects. ''' | ||||
| 
 | ||||
|     def __init__(self): | ||||
|         pass | ||||
|     def __init__(self, condition): | ||||
|         self.condition = condition | ||||
| 
 | ||||
|     @staticmethod | ||||
|     def for_condition(condition): | ||||
|         CONTROLLERS = { | ||||
|     def _controllers(): | ||||
|         return { | ||||
|             conditions.CategoryFlag: CategoryConditionController, | ||||
|             conditions.IncludedProductDiscount: ProductConditionController, | ||||
|             conditions.ProductFlag: ProductConditionController, | ||||
|  | @ -42,8 +109,14 @@ class ConditionController(object): | |||
|             conditions.VoucherFlag: VoucherConditionController, | ||||
|         } | ||||
| 
 | ||||
|     @staticmethod | ||||
|     def for_type(cls): | ||||
|         return ConditionController._controllers()[cls] | ||||
| 
 | ||||
|     @staticmethod | ||||
|     def for_condition(condition): | ||||
|         try: | ||||
|             return CONTROLLERS[type(condition)](condition) | ||||
|             return ConditionController.for_type(type(condition))(condition) | ||||
|         except KeyError: | ||||
|             return ConditionController() | ||||
| 
 | ||||
|  | @ -91,20 +164,9 @@ class ConditionController(object): | |||
|             products = set(products) | ||||
|             quantities = {} | ||||
| 
 | ||||
|         # Get the conditions covered by the products themselves | ||||
|         prods = ( | ||||
|             product.flagbase_set.select_subclasses() | ||||
|             for product in products | ||||
|         ) | ||||
|         # Get the conditions covered by their categories | ||||
|         cats = ( | ||||
|             category.flagbase_set.select_subclasses() | ||||
|             for category in set(product.category for product in products) | ||||
|         ) | ||||
| 
 | ||||
|         if products: | ||||
|             # Simplify the query. | ||||
|             all_conditions = reduce(operator.or_, itertools.chain(prods, cats)) | ||||
|             all_conditions = cls._filtered_flags(user, products) | ||||
|         else: | ||||
|             all_conditions = [] | ||||
| 
 | ||||
|  | @ -114,11 +176,15 @@ class ConditionController(object): | |||
|         do_enable = defaultdict(lambda: False) | ||||
|         # (if either sort of condition is present) | ||||
| 
 | ||||
|         # Count the number of conditions for a product | ||||
|         dif_count = defaultdict(int) | ||||
|         eit_count = defaultdict(int) | ||||
| 
 | ||||
|         messages = {} | ||||
| 
 | ||||
|         for condition in all_conditions: | ||||
|             cond = cls.for_condition(condition) | ||||
|             remainder = cond.user_quantity_remaining(user) | ||||
|             remainder = cond.user_quantity_remaining(user, filtered=True) | ||||
| 
 | ||||
|             # Get all products covered by this condition, and the products | ||||
|             # from the categories covered by this condition | ||||
|  | @ -149,14 +215,41 @@ class ConditionController(object): | |||
|             for product in all_products: | ||||
|                 if condition.is_disable_if_false: | ||||
|                     do_not_disable[product] &= met | ||||
|                     dif_count[product] += 1 | ||||
|                 else: | ||||
|                     do_enable[product] |= met | ||||
|                     eit_count[product] += 1 | ||||
| 
 | ||||
|                 if not met and product not in messages: | ||||
|                     messages[product] = message | ||||
| 
 | ||||
|         total_flags = FlagCounter.count() | ||||
| 
 | ||||
|         valid = {} | ||||
| 
 | ||||
|         # the problem is that now, not every condition falls into | ||||
|         # do_not_disable or do_enable ''' | ||||
|         # You should look into this, chris :) | ||||
| 
 | ||||
|         for product in products: | ||||
|             if quantities: | ||||
|                 if quantities[product] == 0: | ||||
|                     continue | ||||
| 
 | ||||
|             f = total_flags.get(product) | ||||
|             if f.dif > 0 and f.dif != dif_count[product]: | ||||
|                 do_not_disable[product] = False | ||||
|                 if product not in messages: | ||||
|                     messages[product] = "Some disable-if-false " \ | ||||
|                                         "conditions were not met" | ||||
|             if f.eit > 0 and product not in do_enable: | ||||
|                 do_enable[product] = False | ||||
|                 if product not in messages: | ||||
|                     messages[product] = "Some enable-if-true " \ | ||||
|                                         "conditions were not met" | ||||
| 
 | ||||
|         for product in itertools.chain(do_not_disable, do_enable): | ||||
|             f = total_flags.get(product) | ||||
|             if product in do_enable: | ||||
|                 # If there's an enable-if-true, we need need of those met too. | ||||
|                 # (do_not_disable will default to true otherwise) | ||||
|  | @ -172,7 +265,71 @@ class ConditionController(object): | |||
| 
 | ||||
|         return error_fields | ||||
| 
 | ||||
|     def user_quantity_remaining(self, user): | ||||
|     @classmethod | ||||
|     def _filtered_flags(cls, user, products): | ||||
|         ''' | ||||
| 
 | ||||
|         Returns: | ||||
|             Sequence[flagbase]: All flags that passed the filter function. | ||||
| 
 | ||||
|         ''' | ||||
| 
 | ||||
|         types = list(ConditionController._controllers()) | ||||
|         flagtypes = [i for i in types if issubclass(i, conditions.FlagBase)] | ||||
| 
 | ||||
|         # Get all flags for the products and categories. | ||||
|         prods = ( | ||||
|             product.flagbase_set.all() | ||||
|             for product in products | ||||
|         ) | ||||
|         cats = ( | ||||
|             category.flagbase_set.all() | ||||
|             for category in set(product.category for product in products) | ||||
|         ) | ||||
|         all_flags = reduce(operator.or_, itertools.chain(prods, cats)) | ||||
| 
 | ||||
|         all_subsets = [] | ||||
| 
 | ||||
|         for flagtype in flagtypes: | ||||
|             flags = flagtype.objects.filter(id__in=all_flags) | ||||
|             ctrl = ConditionController.for_type(flagtype) | ||||
|             flags = ctrl.pre_filter(flags, user) | ||||
|             all_subsets.append(flags) | ||||
| 
 | ||||
|         return itertools.chain(*all_subsets) | ||||
| 
 | ||||
|     @classmethod | ||||
|     def pre_filter(cls, queryset, user): | ||||
|         ''' Returns only the flag conditions that might be available for this | ||||
|         user. It should hopefully reduce the number of queries that need to be | ||||
|         executed to determine if a flag is met. | ||||
| 
 | ||||
|         If this filtration implements the same query as is_met, then you should | ||||
|         be able to implement ``is_met()`` in terms of this. | ||||
| 
 | ||||
|         Arguments: | ||||
| 
 | ||||
|             queryset (Queryset[c]): The canditate conditions. | ||||
| 
 | ||||
|             user (User): The user for whom we're testing these conditions. | ||||
| 
 | ||||
|         Returns: | ||||
|             Queryset[c]: A subset of the conditions that pass the pre-filter | ||||
|                 test for this user. | ||||
| 
 | ||||
|         ''' | ||||
| 
 | ||||
|         # Default implementation does NOTHING. | ||||
|         return queryset | ||||
| 
 | ||||
|     def passes_filter(self, user): | ||||
|         ''' Returns true if the condition passes the filter ''' | ||||
| 
 | ||||
|         cls = type(self.condition) | ||||
|         qs = cls.objects.filter(pk=self.condition.id) | ||||
|         return self.condition in self.pre_filter(qs, user) | ||||
| 
 | ||||
|     def user_quantity_remaining(self, user, filtered=False): | ||||
|         ''' Returns the number of items covered by this flag condition the | ||||
|         user can add to the current cart. This default implementation returns | ||||
|         a big number if is_met() is true, otherwise 0. | ||||
|  | @ -180,26 +337,37 @@ class ConditionController(object): | |||
|         Either this method, or is_met() must be overridden in subclasses. | ||||
|         ''' | ||||
| 
 | ||||
|         return 99999999 if self.is_met(user) else 0 | ||||
|         return _BIG_QUANTITY if self.is_met(user, filtered) else 0 | ||||
| 
 | ||||
|     def is_met(self, user): | ||||
|     def is_met(self, user, filtered=False): | ||||
|         ''' Returns True if this flag condition is met, otherwise returns | ||||
|         False. | ||||
| 
 | ||||
|         Either this method, or user_quantity_remaining() must be overridden | ||||
|         in subclasses. | ||||
| 
 | ||||
|         Arguments: | ||||
| 
 | ||||
|             user (User): The user for whom this test must be met. | ||||
| 
 | ||||
|             filter (bool): If true, this condition was part of a queryset | ||||
|                 returned by pre_filter() for this user. | ||||
| 
 | ||||
|         ''' | ||||
|         return self.user_quantity_remaining(user) > 0 | ||||
|         return self.user_quantity_remaining(user, filtered) > 0 | ||||
| 
 | ||||
| 
 | ||||
| class CategoryConditionController(ConditionController): | ||||
| class IsMetByFilter(object): | ||||
| 
 | ||||
|     def __init__(self, condition): | ||||
|         self.condition = condition | ||||
|     def is_met(self, user, filtered=False): | ||||
|         ''' Returns True if this flag condition is met, otherwise returns | ||||
|         False. It determines if the condition is met by calling pre_filter | ||||
|         with a queryset containing only self.condition. ''' | ||||
| 
 | ||||
|     def is_met(self, user): | ||||
|         ''' returns True if the user has a product from a category that invokes | ||||
|         this condition in one of their carts ''' | ||||
|         if filtered: | ||||
|             return True  # Why query again? | ||||
| 
 | ||||
|         return self.passes_filter(user) | ||||
| 
 | ||||
|         carts = commerce.Cart.objects.filter(user=user) | ||||
|         carts = carts.exclude(status=commerce.Cart.STATUS_RELEASED) | ||||
|  | @ -212,112 +380,176 @@ class CategoryConditionController(ConditionController): | |||
|         ).count() | ||||
|         return products_count > 0 | ||||
| 
 | ||||
| class RemainderSetByFilter(object): | ||||
| 
 | ||||
| class ProductConditionController(ConditionController): | ||||
|     def user_quantity_remaining(self, user, filtered=True): | ||||
|         ''' returns 0 if the date range is violated, otherwise, it will return | ||||
|         the quantity remaining under the stock limit. | ||||
| 
 | ||||
|         The filter for this condition must add an annotation called "remainder" | ||||
|         in order for this to work. | ||||
|         ''' | ||||
| 
 | ||||
|         if filtered: | ||||
|             if hasattr(self.condition, "remainder"): | ||||
|                 return self.condition.remainder | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
|         # Mark self.condition with a remainder | ||||
|         qs = type(self.condition).objects.filter(pk=self.condition.id) | ||||
|         qs = self.pre_filter(qs, user) | ||||
| 
 | ||||
|         if len(qs) > 0: | ||||
|             return qs[0].remainder | ||||
|         else: | ||||
|             return 0 | ||||
| 
 | ||||
| 
 | ||||
| class CategoryConditionController(IsMetByFilter, ConditionController): | ||||
| 
 | ||||
|     @classmethod | ||||
|     def pre_filter(self, queryset, user): | ||||
|         ''' Returns all of the items from queryset where the user has a | ||||
|         product from a category invoking that item's condition in one of their | ||||
|         carts. ''' | ||||
| 
 | ||||
|         items = commerce.ProductItem.objects.filter(cart__user=user) | ||||
|         items = items.exclude(cart__status=commerce.Cart.STATUS_RELEASED) | ||||
|         items = items.select_related("product", "product__category") | ||||
|         categories = [item.product.category for item in items] | ||||
| 
 | ||||
|         return queryset.filter(enabling_category__in=categories) | ||||
| 
 | ||||
| 
 | ||||
| class ProductConditionController(IsMetByFilter, ConditionController): | ||||
|     ''' Condition tests for ProductFlag and | ||||
|     IncludedProductDiscount. ''' | ||||
| 
 | ||||
|     def __init__(self, condition): | ||||
|         self.condition = condition | ||||
|     @classmethod | ||||
|     def pre_filter(self, queryset, user): | ||||
|         ''' Returns all of the items from queryset where the user has a | ||||
|         product invoking that item's condition in one of their carts. ''' | ||||
| 
 | ||||
|     def is_met(self, user): | ||||
|         ''' returns True if the user has a product that invokes this | ||||
|         condition in one of their carts ''' | ||||
|         items = commerce.ProductItem.objects.filter(cart__user=user) | ||||
|         items = items.exclude(cart__status=commerce.Cart.STATUS_RELEASED) | ||||
|         items = items.select_related("product", "product__category") | ||||
|         products = [item.product for item in items] | ||||
| 
 | ||||
|         carts = commerce.Cart.objects.filter(user=user) | ||||
|         carts = carts.exclude(status=commerce.Cart.STATUS_RELEASED) | ||||
|         products_count = commerce.ProductItem.objects.filter( | ||||
|             cart__in=carts, | ||||
|             product__in=self.condition.enabling_products.all(), | ||||
|         ).count() | ||||
|         return products_count > 0 | ||||
|         return queryset.filter(enabling_products__in=products) | ||||
| 
 | ||||
| 
 | ||||
| class TimeOrStockLimitConditionController(ConditionController): | ||||
| class TimeOrStockLimitConditionController( | ||||
|         RemainderSetByFilter, | ||||
|         ConditionController, | ||||
|     ): | ||||
|     ''' Common condition tests for TimeOrStockLimit Flag and | ||||
|     Discount.''' | ||||
| 
 | ||||
|     def __init__(self, ceiling): | ||||
|         self.ceiling = ceiling | ||||
|     @classmethod | ||||
|     def pre_filter(self, queryset, user): | ||||
|         ''' Returns all of the items from queryset where the date falls into | ||||
|         any specified range, but not yet where the stock limit is not yet | ||||
|         reached.''' | ||||
| 
 | ||||
|     def user_quantity_remaining(self, user): | ||||
|         ''' returns 0 if the date range is violated, otherwise, it will return | ||||
|         the quantity remaining under the stock limit. ''' | ||||
| 
 | ||||
|         # Test date range | ||||
|         if not self._test_date_range(): | ||||
|             return 0 | ||||
| 
 | ||||
|         return self._get_remaining_stock(user) | ||||
| 
 | ||||
|     def _test_date_range(self): | ||||
|         now = timezone.now() | ||||
| 
 | ||||
|         if self.ceiling.start_time is not None: | ||||
|             if now < self.ceiling.start_time: | ||||
|                 return False | ||||
|         # Keep items with no start time, or start time not yet met. | ||||
|         queryset = queryset.filter(Q(start_time=None) | Q(start_time__lte=now)) | ||||
|         queryset = queryset.filter(Q(end_time=None) | Q(end_time__gte=now)) | ||||
| 
 | ||||
|         if self.ceiling.end_time is not None: | ||||
|             if now > self.ceiling.end_time: | ||||
|                 return False | ||||
|         # Filter out items that have been reserved beyond the limits | ||||
|         quantity_or_zero = self._calculate_quantities(user) | ||||
| 
 | ||||
|         return True | ||||
|         remainder = Case( | ||||
|             When(limit=None, then=Value(_BIG_QUANTITY)), | ||||
|             default=F("limit") - Sum(quantity_or_zero), | ||||
|         ) | ||||
| 
 | ||||
|     def _get_remaining_stock(self, user): | ||||
|         ''' Returns the stock that remains under this ceiling, excluding the | ||||
|         user's current cart. ''' | ||||
|         queryset = queryset.annotate(remainder=remainder) | ||||
|         queryset = queryset.filter(remainder__gt=0) | ||||
| 
 | ||||
|         if self.ceiling.limit is None: | ||||
|             return 99999999 | ||||
|         return queryset | ||||
| 
 | ||||
|         # We care about all reserved carts, but not the user's current cart | ||||
|     @classmethod | ||||
|     def _relevant_carts(cls, user): | ||||
|         reserved_carts = commerce.Cart.reserved_carts() | ||||
|         reserved_carts = reserved_carts.exclude( | ||||
|             user=user, | ||||
|             status=commerce.Cart.STATUS_ACTIVE, | ||||
|         ) | ||||
| 
 | ||||
|         items = self._items() | ||||
|         items = items.filter(cart__in=reserved_carts) | ||||
|         count = items.aggregate(Sum("quantity"))["quantity__sum"] or 0 | ||||
| 
 | ||||
|         return self.ceiling.limit - count | ||||
|         return reserved_carts | ||||
| 
 | ||||
| 
 | ||||
| class TimeOrStockLimitFlagController( | ||||
|         TimeOrStockLimitConditionController): | ||||
| 
 | ||||
|     def _items(self): | ||||
|         category_products = inventory.Product.objects.filter( | ||||
|             category__in=self.ceiling.categories.all(), | ||||
|         ) | ||||
|         products = self.ceiling.products.all() | category_products | ||||
|     @classmethod | ||||
|     def _calculate_quantities(cls, user): | ||||
|         reserved_carts = cls._relevant_carts(user) | ||||
| 
 | ||||
|         product_items = commerce.ProductItem.objects.filter( | ||||
|             product__in=products.all(), | ||||
|         # Calculate category lines | ||||
|         cat_items = F('categories__product__productitem__product__category') | ||||
|         reserved_category_products = ( | ||||
|             Q(categories=cat_items) & | ||||
|             Q(categories__product__productitem__cart__in=reserved_carts) | ||||
|         ) | ||||
|         return product_items | ||||
| 
 | ||||
|         # Calculate product lines | ||||
|         reserved_products = ( | ||||
|             Q(products=F('products__productitem__product')) & | ||||
|             Q(products__productitem__cart__in=reserved_carts) | ||||
|         ) | ||||
| 
 | ||||
|         category_quantity_in_reserved_carts = When( | ||||
|             reserved_category_products, | ||||
|             then="categories__product__productitem__quantity", | ||||
|         ) | ||||
| 
 | ||||
|         product_quantity_in_reserved_carts = When( | ||||
|             reserved_products, | ||||
|             then="products__productitem__quantity", | ||||
|         ) | ||||
| 
 | ||||
|         quantity_or_zero = Case( | ||||
|             category_quantity_in_reserved_carts, | ||||
|             product_quantity_in_reserved_carts, | ||||
|             default=Value(0), | ||||
|         ) | ||||
| 
 | ||||
|         return quantity_or_zero | ||||
| 
 | ||||
| 
 | ||||
| class TimeOrStockLimitDiscountController(TimeOrStockLimitConditionController): | ||||
| 
 | ||||
|     def _items(self): | ||||
|         discount_items = commerce.DiscountItem.objects.filter( | ||||
|             discount=self.ceiling, | ||||
|     @classmethod | ||||
|     def _calculate_quantities(cls, user): | ||||
|         reserved_carts = cls._relevant_carts(user) | ||||
| 
 | ||||
|         quantity_in_reserved_carts = When( | ||||
|             discountitem__cart__in=reserved_carts, | ||||
|             then="discountitem__quantity" | ||||
|         ) | ||||
|         return discount_items | ||||
| 
 | ||||
|         quantity_or_zero = Case( | ||||
|             quantity_in_reserved_carts, | ||||
|             default=Value(0) | ||||
|         ) | ||||
| 
 | ||||
|         return quantity_or_zero | ||||
| 
 | ||||
| 
 | ||||
| class VoucherConditionController(ConditionController): | ||||
| class VoucherConditionController(IsMetByFilter, ConditionController): | ||||
|     ''' Condition test for VoucherFlag and VoucherDiscount.''' | ||||
| 
 | ||||
|     def __init__(self, condition): | ||||
|         self.condition = condition | ||||
|     @classmethod | ||||
|     def pre_filter(self, queryset, user): | ||||
|         ''' Returns all of the items from queryset where the user has entered | ||||
|         a voucher that invokes that item's condition in one of their carts. ''' | ||||
| 
 | ||||
|     def is_met(self, user): | ||||
|         ''' returns True if the user has the given voucher attached. ''' | ||||
|         carts_count = commerce.Cart.objects.filter( | ||||
|         carts = commerce.Cart.objects.filter( | ||||
|             user=user, | ||||
|             vouchers=self.condition.voucher, | ||||
|         ).count() | ||||
|         return carts_count > 0 | ||||
|         ) | ||||
|         vouchers = [cart.vouchers.all() for cart in carts] | ||||
| 
 | ||||
|         return queryset.filter(voucher__in=itertools.chain(*vouchers)) | ||||
|  |  | |||
|  | @ -4,7 +4,11 @@ from conditions import ConditionController | |||
| from registrasion.models import commerce | ||||
| from registrasion.models import conditions | ||||
| 
 | ||||
| from django.db.models import Case | ||||
| from django.db.models import Q | ||||
| from django.db.models import Sum | ||||
| from django.db.models import Value | ||||
| from django.db.models import When | ||||
| 
 | ||||
| 
 | ||||
| class DiscountAndQuantity(object): | ||||
|  | @ -43,6 +47,62 @@ def available_discounts(user, categories, products): | |||
|     and products. The discounts also list the available quantity for this user, | ||||
|     not including products that are pending purchase. ''' | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
|     filtered_clauses = _filtered_discounts(user, categories, products) | ||||
| 
 | ||||
|     discounts = [] | ||||
| 
 | ||||
|     # Markers so that we don't need to evaluate given conditions more than once | ||||
|     accepted_discounts = set() | ||||
|     failed_discounts = set() | ||||
| 
 | ||||
|     for clause in filtered_clauses: | ||||
|         discount = clause.discount | ||||
|         cond = ConditionController.for_condition(discount) | ||||
| 
 | ||||
|         past_use_count = discount.past_use_count | ||||
| 
 | ||||
|         # TODO: add test case -- | ||||
|         # discount covers 2x prod_1 and 1x prod_2 | ||||
|         # add 1x prod_2 | ||||
|         # add 1x prod_1 | ||||
|         # checkout | ||||
|         # discount should be available for prod_1 | ||||
| 
 | ||||
|         if past_use_count >= clause.quantity: | ||||
|             # This clause has exceeded its use count | ||||
|             pass | ||||
|         elif discount not in failed_discounts: | ||||
|             # This clause is still available | ||||
|             is_accepted = discount in accepted_discounts | ||||
|             if is_accepted or cond.is_met(user, filtered=True): | ||||
|                 # This clause is valid for this user | ||||
|                 discounts.append(DiscountAndQuantity( | ||||
|                     discount=discount, | ||||
|                     clause=clause, | ||||
|                     quantity=clause.quantity - past_use_count, | ||||
|                 )) | ||||
|                 accepted_discounts.add(discount) | ||||
|             else: | ||||
|                 # This clause is not valid for this user | ||||
|                 failed_discounts.add(discount) | ||||
|     return discounts | ||||
| 
 | ||||
| 
 | ||||
| def _filtered_discounts(user, categories, products): | ||||
|     ''' | ||||
| 
 | ||||
|     Returns: | ||||
|         Sequence[discountbase]: All discounts that passed the filter function. | ||||
| 
 | ||||
|     ''' | ||||
| 
 | ||||
|     types = list(ConditionController._controllers()) | ||||
|     discounttypes = [ | ||||
|         i for i in types if issubclass(i, conditions.DiscountBase) | ||||
|     ] | ||||
| 
 | ||||
|     # discounts that match provided categories | ||||
|     category_discounts = conditions.DiscountForCategory.objects.filter( | ||||
|         category__in=categories | ||||
|  | @ -67,51 +127,56 @@ def available_discounts(user, categories, products): | |||
|         "category", | ||||
|     ) | ||||
| 
 | ||||
|     valid_discounts = conditions.DiscountBase.objects.filter( | ||||
|         Q(discountforproduct__in=product_discounts) | | ||||
|         Q(discountforcategory__in=all_category_discounts) | ||||
|     ) | ||||
| 
 | ||||
|     all_subsets = [] | ||||
| 
 | ||||
|     for discounttype in discounttypes: | ||||
|         discounts = discounttype.objects.filter(id__in=valid_discounts) | ||||
|         ctrl = ConditionController.for_type(discounttype) | ||||
|         discounts = ctrl.pre_filter(discounts, user) | ||||
|         discounts = _annotate_with_past_uses(discounts, user) | ||||
|         all_subsets.append(discounts) | ||||
| 
 | ||||
|     filtered_discounts = list(itertools.chain(*all_subsets)) | ||||
| 
 | ||||
|     # Map from discount key to itself (contains annotations added by filter) | ||||
|     from_filter = dict((i.id, i) for i in filtered_discounts) | ||||
| 
 | ||||
|     # The set of all potential discounts | ||||
|     potential_discounts = set(itertools.chain( | ||||
|         product_discounts, | ||||
|         all_category_discounts, | ||||
|     discount_clauses = set(itertools.chain( | ||||
|         product_discounts.filter(discount__in=filtered_discounts), | ||||
|         all_category_discounts.filter(discount__in=filtered_discounts), | ||||
|     )) | ||||
| 
 | ||||
|     discounts = [] | ||||
|     # Replace discounts with the filtered ones | ||||
|     # These are the correct subclasses (saves query later on), and have | ||||
|     # correct annotations from filters if necessary. | ||||
|     for clause in discount_clauses: | ||||
|         clause.discount = from_filter[clause.discount.id] | ||||
| 
 | ||||
|     # Markers so that we don't need to evaluate given conditions more than once | ||||
|     accepted_discounts = set() | ||||
|     failed_discounts = set() | ||||
|     return discount_clauses | ||||
| 
 | ||||
|     for discount in potential_discounts: | ||||
|         real_discount = conditions.DiscountBase.objects.get_subclass( | ||||
|             pk=discount.discount.pk, | ||||
|         ) | ||||
|         cond = ConditionController.for_condition(real_discount) | ||||
| 
 | ||||
|         # Count the past uses of the given discount item. | ||||
|         # If this user has exceeded the limit for the clause, this clause | ||||
|         # is not available any more. | ||||
|         past_uses = commerce.DiscountItem.objects.filter( | ||||
|             cart__user=user, | ||||
|             cart__status=commerce.Cart.STATUS_PAID,  # Only past carts count | ||||
|             discount=real_discount, | ||||
|         ) | ||||
|         agg = past_uses.aggregate(Sum("quantity")) | ||||
|         past_use_count = agg["quantity__sum"] | ||||
|         if past_use_count is None: | ||||
|             past_use_count = 0 | ||||
| def _annotate_with_past_uses(queryset, user): | ||||
|     ''' Annotates the queryset with a usage count for that discount by the | ||||
|     given user. ''' | ||||
| 
 | ||||
|         if past_use_count >= discount.quantity: | ||||
|             # This clause has exceeded its use count | ||||
|             pass | ||||
|         elif real_discount not in failed_discounts: | ||||
|             # This clause is still available | ||||
|             if real_discount in accepted_discounts or cond.is_met(user): | ||||
|                 # This clause is valid for this user | ||||
|                 discounts.append(DiscountAndQuantity( | ||||
|                     discount=real_discount, | ||||
|                     clause=discount, | ||||
|                     quantity=discount.quantity - past_use_count, | ||||
|                 )) | ||||
|                 accepted_discounts.add(real_discount) | ||||
|             else: | ||||
|                 # This clause is not valid for this user | ||||
|                 failed_discounts.add(real_discount) | ||||
|     return discounts | ||||
|     past_use_quantity = When( | ||||
|         ( | ||||
|             Q(discountitem__cart__user=user) & | ||||
|             Q(discountitem__cart__status=commerce.Cart.STATUS_PAID) | ||||
|         ), | ||||
|         then="discountitem__quantity", | ||||
|     ) | ||||
| 
 | ||||
|     past_use_quantity_or_zero = Case( | ||||
|         past_use_quantity, | ||||
|         default=Value(0), | ||||
|     ) | ||||
| 
 | ||||
|     queryset = queryset.annotate(past_use_count=Sum(past_use_quantity_or_zero)) | ||||
|     return queryset | ||||
|  |  | |||
|  | @ -6,6 +6,8 @@ from django.core.exceptions import ValidationError | |||
| from controller_helpers import TestingCartController | ||||
| from test_cart import RegistrationCartTestCase | ||||
| 
 | ||||
| from registrasion.controllers.discount import available_discounts | ||||
| from registrasion.controllers.product import ProductController | ||||
| from registrasion.models import commerce | ||||
| from registrasion.models import conditions | ||||
| 
 | ||||
|  | @ -135,6 +137,39 @@ class CeilingsTestCases(RegistrationCartTestCase): | |||
|         with self.assertRaises(ValidationError): | ||||
|             first_cart.validate_cart() | ||||
| 
 | ||||
|     def test_discount_ceiling_aggregates_products(self): | ||||
|         # Create two carts, add 1xprod_1 to each. Ceiling should disappear | ||||
|         # after second. | ||||
|         self.make_discount_ceiling( | ||||
|             "Multi-product limit discount ceiling", | ||||
|             limit=2, | ||||
|         ) | ||||
|         for i in xrange(2): | ||||
|             cart = TestingCartController.for_user(self.USER_1) | ||||
|             cart.add_to_cart(self.PROD_1, 1) | ||||
|             cart.next_cart() | ||||
| 
 | ||||
|         discounts = available_discounts(self.USER_1, [], [self.PROD_1]) | ||||
| 
 | ||||
|         self.assertEqual(0, len(discounts)) | ||||
| 
 | ||||
|     def test_flag_ceiling_aggregates_products(self): | ||||
|         # Create two carts, add 1xprod_1 to each. Ceiling should disappear | ||||
|         # after second. | ||||
|         self.make_ceiling("Multi-product limit ceiling", limit=2) | ||||
| 
 | ||||
|         for i in xrange(2): | ||||
|             cart = TestingCartController.for_user(self.USER_1) | ||||
|             cart.add_to_cart(self.PROD_1, 1) | ||||
|             cart.next_cart() | ||||
| 
 | ||||
|         products = ProductController.available_products( | ||||
|             self.USER_1, | ||||
|             products=[self.PROD_1], | ||||
|         ) | ||||
| 
 | ||||
|         self.assertEqual(0, len(products)) | ||||
| 
 | ||||
|     def test_items_released_from_ceiling_by_refund(self): | ||||
|         self.make_ceiling("Limit ceiling", limit=1) | ||||
| 
 | ||||
|  |  | |||
|  | @ -1,2 +1,2 @@ | |||
| [flake8] | ||||
| exclude = registrasion/migrations/*, build/*, docs/* | ||||
| exclude = registrasion/migrations/*, build/*, docs/*, dist/* | ||||
|  |  | |||
		Loading…
	
	Add table
		
		Reference in a new issue
	
	 Christopher Neugebauer
						Christopher Neugebauer