diff --git a/registrasion/controllers/cart.py b/registrasion/controllers/cart.py index 20bd6b5c..e7282e9a 100644 --- a/registrasion/controllers/cart.py +++ b/registrasion/controllers/cart.py @@ -307,6 +307,8 @@ class CartController(object): self._append_errors(errors, ve) # Validate the discounts + # TODO: refactor in terms of available_discounts + # why aren't we doing that here?! discount_items = commerce.DiscountItem.objects.filter(cart=cart) seen_discounts = set() diff --git a/registrasion/controllers/conditions.py b/registrasion/controllers/conditions.py index db40d0c2..291a573a 100644 --- a/registrasion/controllers/conditions.py +++ b/registrasion/controllers/conditions.py @@ -4,7 +4,12 @@ import operator from collections import defaultdict from collections import namedtuple +from django.db.models import Case +from django.db.models import Count +from django.db.models import F, Q from django.db.models import Sum +from django.db.models import Value +from django.db.models import When from django.utils import timezone from registrasion.models import commerce @@ -12,6 +17,7 @@ from registrasion.models import conditions from registrasion.models import inventory + ConditionAndRemainder = namedtuple( "ConditionAndRemainder", ( @@ -21,16 +27,77 @@ ConditionAndRemainder = namedtuple( ) +_FlagCounter = namedtuple( + "_FlagCounter", + ( + "products", + "categories", + ), +) + + +_ConditionsCount = namedtuple( + "ConditionsCount", + ( + "dif", + "eit", + ), +) + + +class FlagCounter(_FlagCounter): + + @classmethod + def count(cls): + # Get the count of how many conditions should exist per product + flagbases = conditions.FlagBase.objects + + types = ( + conditions.FlagBase.ENABLE_IF_TRUE, + conditions.FlagBase.DISABLE_IF_FALSE, + ) + keys = ("eit", "dif") + flags = [ + flagbases.filter( + condition=condition_type + ).values( + 'products', 'categories' + ).annotate( + count=Count('id') + ) + for condition_type in types + ] + + cats = defaultdict(lambda: defaultdict(int)) + prod = defaultdict(lambda: defaultdict(int)) + + for key, flagcounts in zip(keys, flags): + for row in flagcounts: + if row["products"] is not None: + prod[row["products"]][key] = row["count"] + if row["categories"] is not None: + cats[row["categories"]][key] = row["count"] + + return cls(products=prod, categories=cats) + + def get(self, product): + p = self.products[product.id] + c = self.categories[product.category.id] + eit = p["eit"] + c["eit"] + dif = p["dif"] + c["dif"] + return _ConditionsCount(dif=dif, eit=eit) + + class ConditionController(object): ''' Base class for testing conditions that activate Flag or Discount objects. ''' - def __init__(self): - pass + def __init__(self, condition): + self.condition = condition @staticmethod - def for_condition(condition): - CONTROLLERS = { + def _controllers(): + return { conditions.CategoryFlag: CategoryConditionController, conditions.IncludedProductDiscount: ProductConditionController, conditions.ProductFlag: ProductConditionController, @@ -42,8 +109,14 @@ class ConditionController(object): conditions.VoucherFlag: VoucherConditionController, } + @staticmethod + def for_type(cls): + return ConditionController._controllers()[cls] + + @staticmethod + def for_condition(condition): try: - return CONTROLLERS[type(condition)](condition) + return ConditionController.for_type(type(condition))(condition) except KeyError: return ConditionController() @@ -91,20 +164,9 @@ class ConditionController(object): products = set(products) quantities = {} - # Get the conditions covered by the products themselves - prods = ( - product.flagbase_set.select_subclasses() - for product in products - ) - # Get the conditions covered by their categories - cats = ( - category.flagbase_set.select_subclasses() - for category in set(product.category for product in products) - ) - if products: # Simplify the query. - all_conditions = reduce(operator.or_, itertools.chain(prods, cats)) + all_conditions = cls._filtered_flags(user, products) else: all_conditions = [] @@ -114,11 +176,15 @@ class ConditionController(object): do_enable = defaultdict(lambda: False) # (if either sort of condition is present) + # Count the number of conditions for a product + dif_count = defaultdict(int) + eit_count = defaultdict(int) + messages = {} for condition in all_conditions: cond = cls.for_condition(condition) - remainder = cond.user_quantity_remaining(user) + remainder = cond.user_quantity_remaining(user, filtered=True) # Get all products covered by this condition, and the products # from the categories covered by this condition @@ -149,14 +215,41 @@ class ConditionController(object): for product in all_products: if condition.is_disable_if_false: do_not_disable[product] &= met + dif_count[product] += 1 else: do_enable[product] |= met + eit_count[product] += 1 if not met and product not in messages: messages[product] = message + total_flags = FlagCounter.count() + valid = {} + + # the problem is that now, not every condition falls into + # do_not_disable or do_enable ''' + # You should look into this, chris :) + + for product in products: + if quantities: + if quantities[product] == 0: + continue + + f = total_flags.get(product) + if f.dif > 0 and f.dif != dif_count[product]: + do_not_disable[product] = False + if product not in messages: + messages[product] = "Some disable-if-false " \ + "conditions were not met" + if f.eit > 0 and product not in do_enable: + do_enable[product] = False + if product not in messages: + messages[product] = "Some enable-if-true " \ + "conditions were not met" + for product in itertools.chain(do_not_disable, do_enable): + f = total_flags.get(product) if product in do_enable: # If there's an enable-if-true, we need need of those met too. # (do_not_disable will default to true otherwise) @@ -172,7 +265,71 @@ class ConditionController(object): return error_fields - def user_quantity_remaining(self, user): + @classmethod + def _filtered_flags(cls, user, products): + ''' + + Returns: + Sequence[flagbase]: All flags that passed the filter function. + + ''' + + types = list(ConditionController._controllers()) + flagtypes = [i for i in types if issubclass(i, conditions.FlagBase)] + + # Get all flags for the products and categories. + prods = ( + product.flagbase_set.all() + for product in products + ) + cats = ( + category.flagbase_set.all() + for category in set(product.category for product in products) + ) + all_flags = reduce(operator.or_, itertools.chain(prods, cats)) + + all_subsets = [] + + for flagtype in flagtypes: + flags = flagtype.objects.filter(id__in=all_flags) + ctrl = ConditionController.for_type(flagtype) + flags = ctrl.pre_filter(flags, user) + all_subsets.append(flags) + + return itertools.chain(*all_subsets) + + @classmethod + def pre_filter(cls, queryset, user): + ''' Returns only the flag conditions that might be available for this + user. It should hopefully reduce the number of queries that need to be + executed to determine if a flag is met. + + If this filtration implements the same query as is_met, then you should + be able to implement ``is_met()`` in terms of this. + + Arguments: + + queryset (Queryset[c]): The canditate conditions. + + user (User): The user for whom we're testing these conditions. + + Returns: + Queryset[c]: A subset of the conditions that pass the pre-filter + test for this user. + + ''' + + # Default implementation does NOTHING. + return queryset + + def passes_filter(self, user): + ''' Returns true if the condition passes the filter ''' + + cls = type(self.condition) + qs = cls.objects.filter(pk=self.condition.id) + return self.condition in self.pre_filter(qs, user) + + def user_quantity_remaining(self, user, filtered=False): ''' Returns the number of items covered by this flag condition the user can add to the current cart. This default implementation returns a big number if is_met() is true, otherwise 0. @@ -180,26 +337,37 @@ class ConditionController(object): Either this method, or is_met() must be overridden in subclasses. ''' - return 99999999 if self.is_met(user) else 0 + return _BIG_QUANTITY if self.is_met(user, filtered) else 0 - def is_met(self, user): + def is_met(self, user, filtered=False): ''' Returns True if this flag condition is met, otherwise returns False. Either this method, or user_quantity_remaining() must be overridden in subclasses. + + Arguments: + + user (User): The user for whom this test must be met. + + filter (bool): If true, this condition was part of a queryset + returned by pre_filter() for this user. + ''' - return self.user_quantity_remaining(user) > 0 + return self.user_quantity_remaining(user, filtered) > 0 -class CategoryConditionController(ConditionController): +class IsMetByFilter(object): - def __init__(self, condition): - self.condition = condition + def is_met(self, user, filtered=False): + ''' Returns True if this flag condition is met, otherwise returns + False. It determines if the condition is met by calling pre_filter + with a queryset containing only self.condition. ''' - def is_met(self, user): - ''' returns True if the user has a product from a category that invokes - this condition in one of their carts ''' + if filtered: + return True # Why query again? + + return self.passes_filter(user) carts = commerce.Cart.objects.filter(user=user) carts = carts.exclude(status=commerce.Cart.STATUS_RELEASED) @@ -212,112 +380,176 @@ class CategoryConditionController(ConditionController): ).count() return products_count > 0 +class RemainderSetByFilter(object): -class ProductConditionController(ConditionController): + def user_quantity_remaining(self, user, filtered=True): + ''' returns 0 if the date range is violated, otherwise, it will return + the quantity remaining under the stock limit. + + The filter for this condition must add an annotation called "remainder" + in order for this to work. + ''' + + if filtered: + if hasattr(self.condition, "remainder"): + return self.condition.remainder + + + + # Mark self.condition with a remainder + qs = type(self.condition).objects.filter(pk=self.condition.id) + qs = self.pre_filter(qs, user) + + if len(qs) > 0: + return qs[0].remainder + else: + return 0 + + +class CategoryConditionController(IsMetByFilter, ConditionController): + + @classmethod + def pre_filter(self, queryset, user): + ''' Returns all of the items from queryset where the user has a + product from a category invoking that item's condition in one of their + carts. ''' + + items = commerce.ProductItem.objects.filter(cart__user=user) + items = items.exclude(cart__status=commerce.Cart.STATUS_RELEASED) + items = items.select_related("product", "product__category") + categories = [item.product.category for item in items] + + return queryset.filter(enabling_category__in=categories) + + +class ProductConditionController(IsMetByFilter, ConditionController): ''' Condition tests for ProductFlag and IncludedProductDiscount. ''' - def __init__(self, condition): - self.condition = condition + @classmethod + def pre_filter(self, queryset, user): + ''' Returns all of the items from queryset where the user has a + product invoking that item's condition in one of their carts. ''' - def is_met(self, user): - ''' returns True if the user has a product that invokes this - condition in one of their carts ''' + items = commerce.ProductItem.objects.filter(cart__user=user) + items = items.exclude(cart__status=commerce.Cart.STATUS_RELEASED) + items = items.select_related("product", "product__category") + products = [item.product for item in items] - carts = commerce.Cart.objects.filter(user=user) - carts = carts.exclude(status=commerce.Cart.STATUS_RELEASED) - products_count = commerce.ProductItem.objects.filter( - cart__in=carts, - product__in=self.condition.enabling_products.all(), - ).count() - return products_count > 0 + return queryset.filter(enabling_products__in=products) -class TimeOrStockLimitConditionController(ConditionController): +class TimeOrStockLimitConditionController( + RemainderSetByFilter, + ConditionController, + ): ''' Common condition tests for TimeOrStockLimit Flag and Discount.''' - def __init__(self, ceiling): - self.ceiling = ceiling + @classmethod + def pre_filter(self, queryset, user): + ''' Returns all of the items from queryset where the date falls into + any specified range, but not yet where the stock limit is not yet + reached.''' - def user_quantity_remaining(self, user): - ''' returns 0 if the date range is violated, otherwise, it will return - the quantity remaining under the stock limit. ''' - - # Test date range - if not self._test_date_range(): - return 0 - - return self._get_remaining_stock(user) - - def _test_date_range(self): now = timezone.now() - if self.ceiling.start_time is not None: - if now < self.ceiling.start_time: - return False + # Keep items with no start time, or start time not yet met. + queryset = queryset.filter(Q(start_time=None) | Q(start_time__lte=now)) + queryset = queryset.filter(Q(end_time=None) | Q(end_time__gte=now)) - if self.ceiling.end_time is not None: - if now > self.ceiling.end_time: - return False + # Filter out items that have been reserved beyond the limits + quantity_or_zero = self._calculate_quantities(user) - return True + remainder = Case( + When(limit=None, then=Value(_BIG_QUANTITY)), + default=F("limit") - Sum(quantity_or_zero), + ) - def _get_remaining_stock(self, user): - ''' Returns the stock that remains under this ceiling, excluding the - user's current cart. ''' + queryset = queryset.annotate(remainder=remainder) + queryset = queryset.filter(remainder__gt=0) - if self.ceiling.limit is None: - return 99999999 + return queryset - # We care about all reserved carts, but not the user's current cart + @classmethod + def _relevant_carts(cls, user): reserved_carts = commerce.Cart.reserved_carts() reserved_carts = reserved_carts.exclude( user=user, status=commerce.Cart.STATUS_ACTIVE, ) - - items = self._items() - items = items.filter(cart__in=reserved_carts) - count = items.aggregate(Sum("quantity"))["quantity__sum"] or 0 - - return self.ceiling.limit - count + return reserved_carts class TimeOrStockLimitFlagController( TimeOrStockLimitConditionController): - def _items(self): - category_products = inventory.Product.objects.filter( - category__in=self.ceiling.categories.all(), - ) - products = self.ceiling.products.all() | category_products + @classmethod + def _calculate_quantities(cls, user): + reserved_carts = cls._relevant_carts(user) - product_items = commerce.ProductItem.objects.filter( - product__in=products.all(), + # Calculate category lines + cat_items = F('categories__product__productitem__product__category') + reserved_category_products = ( + Q(categories=cat_items) & + Q(categories__product__productitem__cart__in=reserved_carts) ) - return product_items + + # Calculate product lines + reserved_products = ( + Q(products=F('products__productitem__product')) & + Q(products__productitem__cart__in=reserved_carts) + ) + + category_quantity_in_reserved_carts = When( + reserved_category_products, + then="categories__product__productitem__quantity", + ) + + product_quantity_in_reserved_carts = When( + reserved_products, + then="products__productitem__quantity", + ) + + quantity_or_zero = Case( + category_quantity_in_reserved_carts, + product_quantity_in_reserved_carts, + default=Value(0), + ) + + return quantity_or_zero class TimeOrStockLimitDiscountController(TimeOrStockLimitConditionController): - def _items(self): - discount_items = commerce.DiscountItem.objects.filter( - discount=self.ceiling, + @classmethod + def _calculate_quantities(cls, user): + reserved_carts = cls._relevant_carts(user) + + quantity_in_reserved_carts = When( + discountitem__cart__in=reserved_carts, + then="discountitem__quantity" ) - return discount_items + + quantity_or_zero = Case( + quantity_in_reserved_carts, + default=Value(0) + ) + + return quantity_or_zero -class VoucherConditionController(ConditionController): +class VoucherConditionController(IsMetByFilter, ConditionController): ''' Condition test for VoucherFlag and VoucherDiscount.''' - def __init__(self, condition): - self.condition = condition + @classmethod + def pre_filter(self, queryset, user): + ''' Returns all of the items from queryset where the user has entered + a voucher that invokes that item's condition in one of their carts. ''' - def is_met(self, user): - ''' returns True if the user has the given voucher attached. ''' - carts_count = commerce.Cart.objects.filter( + carts = commerce.Cart.objects.filter( user=user, - vouchers=self.condition.voucher, - ).count() - return carts_count > 0 + ) + vouchers = [cart.vouchers.all() for cart in carts] + + return queryset.filter(voucher__in=itertools.chain(*vouchers)) diff --git a/registrasion/controllers/discount.py b/registrasion/controllers/discount.py index a8f9282e..21d8a51b 100644 --- a/registrasion/controllers/discount.py +++ b/registrasion/controllers/discount.py @@ -4,7 +4,11 @@ from conditions import ConditionController from registrasion.models import commerce from registrasion.models import conditions +from django.db.models import Case +from django.db.models import Q from django.db.models import Sum +from django.db.models import Value +from django.db.models import When class DiscountAndQuantity(object): @@ -43,6 +47,62 @@ def available_discounts(user, categories, products): and products. The discounts also list the available quantity for this user, not including products that are pending purchase. ''' + + + filtered_clauses = _filtered_discounts(user, categories, products) + + discounts = [] + + # Markers so that we don't need to evaluate given conditions more than once + accepted_discounts = set() + failed_discounts = set() + + for clause in filtered_clauses: + discount = clause.discount + cond = ConditionController.for_condition(discount) + + past_use_count = discount.past_use_count + + # TODO: add test case -- + # discount covers 2x prod_1 and 1x prod_2 + # add 1x prod_2 + # add 1x prod_1 + # checkout + # discount should be available for prod_1 + + if past_use_count >= clause.quantity: + # This clause has exceeded its use count + pass + elif discount not in failed_discounts: + # This clause is still available + is_accepted = discount in accepted_discounts + if is_accepted or cond.is_met(user, filtered=True): + # This clause is valid for this user + discounts.append(DiscountAndQuantity( + discount=discount, + clause=clause, + quantity=clause.quantity - past_use_count, + )) + accepted_discounts.add(discount) + else: + # This clause is not valid for this user + failed_discounts.add(discount) + return discounts + + +def _filtered_discounts(user, categories, products): + ''' + + Returns: + Sequence[discountbase]: All discounts that passed the filter function. + + ''' + + types = list(ConditionController._controllers()) + discounttypes = [ + i for i in types if issubclass(i, conditions.DiscountBase) + ] + # discounts that match provided categories category_discounts = conditions.DiscountForCategory.objects.filter( category__in=categories @@ -67,51 +127,56 @@ def available_discounts(user, categories, products): "category", ) + valid_discounts = conditions.DiscountBase.objects.filter( + Q(discountforproduct__in=product_discounts) | + Q(discountforcategory__in=all_category_discounts) + ) + + all_subsets = [] + + for discounttype in discounttypes: + discounts = discounttype.objects.filter(id__in=valid_discounts) + ctrl = ConditionController.for_type(discounttype) + discounts = ctrl.pre_filter(discounts, user) + discounts = _annotate_with_past_uses(discounts, user) + all_subsets.append(discounts) + + filtered_discounts = list(itertools.chain(*all_subsets)) + + # Map from discount key to itself (contains annotations added by filter) + from_filter = dict((i.id, i) for i in filtered_discounts) + # The set of all potential discounts - potential_discounts = set(itertools.chain( - product_discounts, - all_category_discounts, + discount_clauses = set(itertools.chain( + product_discounts.filter(discount__in=filtered_discounts), + all_category_discounts.filter(discount__in=filtered_discounts), )) - discounts = [] + # Replace discounts with the filtered ones + # These are the correct subclasses (saves query later on), and have + # correct annotations from filters if necessary. + for clause in discount_clauses: + clause.discount = from_filter[clause.discount.id] - # Markers so that we don't need to evaluate given conditions more than once - accepted_discounts = set() - failed_discounts = set() + return discount_clauses - for discount in potential_discounts: - real_discount = conditions.DiscountBase.objects.get_subclass( - pk=discount.discount.pk, - ) - cond = ConditionController.for_condition(real_discount) - # Count the past uses of the given discount item. - # If this user has exceeded the limit for the clause, this clause - # is not available any more. - past_uses = commerce.DiscountItem.objects.filter( - cart__user=user, - cart__status=commerce.Cart.STATUS_PAID, # Only past carts count - discount=real_discount, - ) - agg = past_uses.aggregate(Sum("quantity")) - past_use_count = agg["quantity__sum"] - if past_use_count is None: - past_use_count = 0 +def _annotate_with_past_uses(queryset, user): + ''' Annotates the queryset with a usage count for that discount by the + given user. ''' - if past_use_count >= discount.quantity: - # This clause has exceeded its use count - pass - elif real_discount not in failed_discounts: - # This clause is still available - if real_discount in accepted_discounts or cond.is_met(user): - # This clause is valid for this user - discounts.append(DiscountAndQuantity( - discount=real_discount, - clause=discount, - quantity=discount.quantity - past_use_count, - )) - accepted_discounts.add(real_discount) - else: - # This clause is not valid for this user - failed_discounts.add(real_discount) - return discounts + past_use_quantity = When( + ( + Q(discountitem__cart__user=user) & + Q(discountitem__cart__status=commerce.Cart.STATUS_PAID) + ), + then="discountitem__quantity", + ) + + past_use_quantity_or_zero = Case( + past_use_quantity, + default=Value(0), + ) + + queryset = queryset.annotate(past_use_count=Sum(past_use_quantity_or_zero)) + return queryset diff --git a/registrasion/tests/test_ceilings.py b/registrasion/tests/test_ceilings.py index 877556d0..1273a071 100644 --- a/registrasion/tests/test_ceilings.py +++ b/registrasion/tests/test_ceilings.py @@ -6,6 +6,8 @@ from django.core.exceptions import ValidationError from controller_helpers import TestingCartController from test_cart import RegistrationCartTestCase +from registrasion.controllers.discount import available_discounts +from registrasion.controllers.product import ProductController from registrasion.models import commerce from registrasion.models import conditions @@ -135,6 +137,39 @@ class CeilingsTestCases(RegistrationCartTestCase): with self.assertRaises(ValidationError): first_cart.validate_cart() + def test_discount_ceiling_aggregates_products(self): + # Create two carts, add 1xprod_1 to each. Ceiling should disappear + # after second. + self.make_discount_ceiling( + "Multi-product limit discount ceiling", + limit=2, + ) + for i in xrange(2): + cart = TestingCartController.for_user(self.USER_1) + cart.add_to_cart(self.PROD_1, 1) + cart.next_cart() + + discounts = available_discounts(self.USER_1, [], [self.PROD_1]) + + self.assertEqual(0, len(discounts)) + + def test_flag_ceiling_aggregates_products(self): + # Create two carts, add 1xprod_1 to each. Ceiling should disappear + # after second. + self.make_ceiling("Multi-product limit ceiling", limit=2) + + for i in xrange(2): + cart = TestingCartController.for_user(self.USER_1) + cart.add_to_cart(self.PROD_1, 1) + cart.next_cart() + + products = ProductController.available_products( + self.USER_1, + products=[self.PROD_1], + ) + + self.assertEqual(0, len(products)) + def test_items_released_from_ceiling_by_refund(self): self.make_ceiling("Limit ceiling", limit=1) diff --git a/setup.cfg b/setup.cfg index 290fdb45..9b05640c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,2 +1,2 @@ [flake8] -exclude = registrasion/migrations/*, build/*, docs/* +exclude = registrasion/migrations/*, build/*, docs/*, dist/*