From 3f1be0e14e06b5a24ee5669648d823d84c47f55b Mon Sep 17 00:00:00 2001 From: Christopher Neugebauer Date: Wed, 27 Apr 2016 11:46:44 +1000 Subject: [PATCH 01/15] Rearchitected condition processing such that multiple conditions are processed by the database, in bulk. Closes #42. --- registrasion/controllers/cart.py | 2 + registrasion/controllers/conditions.py | 422 +++++++++++++++++++------ registrasion/controllers/discount.py | 147 ++++++--- registrasion/tests/test_ceilings.py | 35 ++ setup.cfg | 2 +- 5 files changed, 471 insertions(+), 137 deletions(-) diff --git a/registrasion/controllers/cart.py b/registrasion/controllers/cart.py index 20bd6b5c..e7282e9a 100644 --- a/registrasion/controllers/cart.py +++ b/registrasion/controllers/cart.py @@ -307,6 +307,8 @@ class CartController(object): self._append_errors(errors, ve) # Validate the discounts + # TODO: refactor in terms of available_discounts + # why aren't we doing that here?! discount_items = commerce.DiscountItem.objects.filter(cart=cart) seen_discounts = set() diff --git a/registrasion/controllers/conditions.py b/registrasion/controllers/conditions.py index db40d0c2..291a573a 100644 --- a/registrasion/controllers/conditions.py +++ b/registrasion/controllers/conditions.py @@ -4,7 +4,12 @@ import operator from collections import defaultdict from collections import namedtuple +from django.db.models import Case +from django.db.models import Count +from django.db.models import F, Q from django.db.models import Sum +from django.db.models import Value +from django.db.models import When from django.utils import timezone from registrasion.models import commerce @@ -12,6 +17,7 @@ from registrasion.models import conditions from registrasion.models import inventory + ConditionAndRemainder = namedtuple( "ConditionAndRemainder", ( @@ -21,16 +27,77 @@ ConditionAndRemainder = namedtuple( ) +_FlagCounter = namedtuple( + "_FlagCounter", + ( + "products", + "categories", + ), +) + + +_ConditionsCount = namedtuple( + "ConditionsCount", + ( + "dif", + "eit", + ), +) + + +class FlagCounter(_FlagCounter): + + @classmethod + def count(cls): + # Get the count of how many conditions should exist per product + flagbases = conditions.FlagBase.objects + + types = ( + conditions.FlagBase.ENABLE_IF_TRUE, + conditions.FlagBase.DISABLE_IF_FALSE, + ) + keys = ("eit", "dif") + flags = [ + flagbases.filter( + condition=condition_type + ).values( + 'products', 'categories' + ).annotate( + count=Count('id') + ) + for condition_type in types + ] + + cats = defaultdict(lambda: defaultdict(int)) + prod = defaultdict(lambda: defaultdict(int)) + + for key, flagcounts in zip(keys, flags): + for row in flagcounts: + if row["products"] is not None: + prod[row["products"]][key] = row["count"] + if row["categories"] is not None: + cats[row["categories"]][key] = row["count"] + + return cls(products=prod, categories=cats) + + def get(self, product): + p = self.products[product.id] + c = self.categories[product.category.id] + eit = p["eit"] + c["eit"] + dif = p["dif"] + c["dif"] + return _ConditionsCount(dif=dif, eit=eit) + + class ConditionController(object): ''' Base class for testing conditions that activate Flag or Discount objects. ''' - def __init__(self): - pass + def __init__(self, condition): + self.condition = condition @staticmethod - def for_condition(condition): - CONTROLLERS = { + def _controllers(): + return { conditions.CategoryFlag: CategoryConditionController, conditions.IncludedProductDiscount: ProductConditionController, conditions.ProductFlag: ProductConditionController, @@ -42,8 +109,14 @@ class ConditionController(object): conditions.VoucherFlag: VoucherConditionController, } + @staticmethod + def for_type(cls): + return ConditionController._controllers()[cls] + + @staticmethod + def for_condition(condition): try: - return CONTROLLERS[type(condition)](condition) + return ConditionController.for_type(type(condition))(condition) except KeyError: return ConditionController() @@ -91,20 +164,9 @@ class ConditionController(object): products = set(products) quantities = {} - # Get the conditions covered by the products themselves - prods = ( - product.flagbase_set.select_subclasses() - for product in products - ) - # Get the conditions covered by their categories - cats = ( - category.flagbase_set.select_subclasses() - for category in set(product.category for product in products) - ) - if products: # Simplify the query. - all_conditions = reduce(operator.or_, itertools.chain(prods, cats)) + all_conditions = cls._filtered_flags(user, products) else: all_conditions = [] @@ -114,11 +176,15 @@ class ConditionController(object): do_enable = defaultdict(lambda: False) # (if either sort of condition is present) + # Count the number of conditions for a product + dif_count = defaultdict(int) + eit_count = defaultdict(int) + messages = {} for condition in all_conditions: cond = cls.for_condition(condition) - remainder = cond.user_quantity_remaining(user) + remainder = cond.user_quantity_remaining(user, filtered=True) # Get all products covered by this condition, and the products # from the categories covered by this condition @@ -149,14 +215,41 @@ class ConditionController(object): for product in all_products: if condition.is_disable_if_false: do_not_disable[product] &= met + dif_count[product] += 1 else: do_enable[product] |= met + eit_count[product] += 1 if not met and product not in messages: messages[product] = message + total_flags = FlagCounter.count() + valid = {} + + # the problem is that now, not every condition falls into + # do_not_disable or do_enable ''' + # You should look into this, chris :) + + for product in products: + if quantities: + if quantities[product] == 0: + continue + + f = total_flags.get(product) + if f.dif > 0 and f.dif != dif_count[product]: + do_not_disable[product] = False + if product not in messages: + messages[product] = "Some disable-if-false " \ + "conditions were not met" + if f.eit > 0 and product not in do_enable: + do_enable[product] = False + if product not in messages: + messages[product] = "Some enable-if-true " \ + "conditions were not met" + for product in itertools.chain(do_not_disable, do_enable): + f = total_flags.get(product) if product in do_enable: # If there's an enable-if-true, we need need of those met too. # (do_not_disable will default to true otherwise) @@ -172,7 +265,71 @@ class ConditionController(object): return error_fields - def user_quantity_remaining(self, user): + @classmethod + def _filtered_flags(cls, user, products): + ''' + + Returns: + Sequence[flagbase]: All flags that passed the filter function. + + ''' + + types = list(ConditionController._controllers()) + flagtypes = [i for i in types if issubclass(i, conditions.FlagBase)] + + # Get all flags for the products and categories. + prods = ( + product.flagbase_set.all() + for product in products + ) + cats = ( + category.flagbase_set.all() + for category in set(product.category for product in products) + ) + all_flags = reduce(operator.or_, itertools.chain(prods, cats)) + + all_subsets = [] + + for flagtype in flagtypes: + flags = flagtype.objects.filter(id__in=all_flags) + ctrl = ConditionController.for_type(flagtype) + flags = ctrl.pre_filter(flags, user) + all_subsets.append(flags) + + return itertools.chain(*all_subsets) + + @classmethod + def pre_filter(cls, queryset, user): + ''' Returns only the flag conditions that might be available for this + user. It should hopefully reduce the number of queries that need to be + executed to determine if a flag is met. + + If this filtration implements the same query as is_met, then you should + be able to implement ``is_met()`` in terms of this. + + Arguments: + + queryset (Queryset[c]): The canditate conditions. + + user (User): The user for whom we're testing these conditions. + + Returns: + Queryset[c]: A subset of the conditions that pass the pre-filter + test for this user. + + ''' + + # Default implementation does NOTHING. + return queryset + + def passes_filter(self, user): + ''' Returns true if the condition passes the filter ''' + + cls = type(self.condition) + qs = cls.objects.filter(pk=self.condition.id) + return self.condition in self.pre_filter(qs, user) + + def user_quantity_remaining(self, user, filtered=False): ''' Returns the number of items covered by this flag condition the user can add to the current cart. This default implementation returns a big number if is_met() is true, otherwise 0. @@ -180,26 +337,37 @@ class ConditionController(object): Either this method, or is_met() must be overridden in subclasses. ''' - return 99999999 if self.is_met(user) else 0 + return _BIG_QUANTITY if self.is_met(user, filtered) else 0 - def is_met(self, user): + def is_met(self, user, filtered=False): ''' Returns True if this flag condition is met, otherwise returns False. Either this method, or user_quantity_remaining() must be overridden in subclasses. + + Arguments: + + user (User): The user for whom this test must be met. + + filter (bool): If true, this condition was part of a queryset + returned by pre_filter() for this user. + ''' - return self.user_quantity_remaining(user) > 0 + return self.user_quantity_remaining(user, filtered) > 0 -class CategoryConditionController(ConditionController): +class IsMetByFilter(object): - def __init__(self, condition): - self.condition = condition + def is_met(self, user, filtered=False): + ''' Returns True if this flag condition is met, otherwise returns + False. It determines if the condition is met by calling pre_filter + with a queryset containing only self.condition. ''' - def is_met(self, user): - ''' returns True if the user has a product from a category that invokes - this condition in one of their carts ''' + if filtered: + return True # Why query again? + + return self.passes_filter(user) carts = commerce.Cart.objects.filter(user=user) carts = carts.exclude(status=commerce.Cart.STATUS_RELEASED) @@ -212,112 +380,176 @@ class CategoryConditionController(ConditionController): ).count() return products_count > 0 +class RemainderSetByFilter(object): -class ProductConditionController(ConditionController): + def user_quantity_remaining(self, user, filtered=True): + ''' returns 0 if the date range is violated, otherwise, it will return + the quantity remaining under the stock limit. + + The filter for this condition must add an annotation called "remainder" + in order for this to work. + ''' + + if filtered: + if hasattr(self.condition, "remainder"): + return self.condition.remainder + + + + # Mark self.condition with a remainder + qs = type(self.condition).objects.filter(pk=self.condition.id) + qs = self.pre_filter(qs, user) + + if len(qs) > 0: + return qs[0].remainder + else: + return 0 + + +class CategoryConditionController(IsMetByFilter, ConditionController): + + @classmethod + def pre_filter(self, queryset, user): + ''' Returns all of the items from queryset where the user has a + product from a category invoking that item's condition in one of their + carts. ''' + + items = commerce.ProductItem.objects.filter(cart__user=user) + items = items.exclude(cart__status=commerce.Cart.STATUS_RELEASED) + items = items.select_related("product", "product__category") + categories = [item.product.category for item in items] + + return queryset.filter(enabling_category__in=categories) + + +class ProductConditionController(IsMetByFilter, ConditionController): ''' Condition tests for ProductFlag and IncludedProductDiscount. ''' - def __init__(self, condition): - self.condition = condition + @classmethod + def pre_filter(self, queryset, user): + ''' Returns all of the items from queryset where the user has a + product invoking that item's condition in one of their carts. ''' - def is_met(self, user): - ''' returns True if the user has a product that invokes this - condition in one of their carts ''' + items = commerce.ProductItem.objects.filter(cart__user=user) + items = items.exclude(cart__status=commerce.Cart.STATUS_RELEASED) + items = items.select_related("product", "product__category") + products = [item.product for item in items] - carts = commerce.Cart.objects.filter(user=user) - carts = carts.exclude(status=commerce.Cart.STATUS_RELEASED) - products_count = commerce.ProductItem.objects.filter( - cart__in=carts, - product__in=self.condition.enabling_products.all(), - ).count() - return products_count > 0 + return queryset.filter(enabling_products__in=products) -class TimeOrStockLimitConditionController(ConditionController): +class TimeOrStockLimitConditionController( + RemainderSetByFilter, + ConditionController, + ): ''' Common condition tests for TimeOrStockLimit Flag and Discount.''' - def __init__(self, ceiling): - self.ceiling = ceiling + @classmethod + def pre_filter(self, queryset, user): + ''' Returns all of the items from queryset where the date falls into + any specified range, but not yet where the stock limit is not yet + reached.''' - def user_quantity_remaining(self, user): - ''' returns 0 if the date range is violated, otherwise, it will return - the quantity remaining under the stock limit. ''' - - # Test date range - if not self._test_date_range(): - return 0 - - return self._get_remaining_stock(user) - - def _test_date_range(self): now = timezone.now() - if self.ceiling.start_time is not None: - if now < self.ceiling.start_time: - return False + # Keep items with no start time, or start time not yet met. + queryset = queryset.filter(Q(start_time=None) | Q(start_time__lte=now)) + queryset = queryset.filter(Q(end_time=None) | Q(end_time__gte=now)) - if self.ceiling.end_time is not None: - if now > self.ceiling.end_time: - return False + # Filter out items that have been reserved beyond the limits + quantity_or_zero = self._calculate_quantities(user) - return True + remainder = Case( + When(limit=None, then=Value(_BIG_QUANTITY)), + default=F("limit") - Sum(quantity_or_zero), + ) - def _get_remaining_stock(self, user): - ''' Returns the stock that remains under this ceiling, excluding the - user's current cart. ''' + queryset = queryset.annotate(remainder=remainder) + queryset = queryset.filter(remainder__gt=0) - if self.ceiling.limit is None: - return 99999999 + return queryset - # We care about all reserved carts, but not the user's current cart + @classmethod + def _relevant_carts(cls, user): reserved_carts = commerce.Cart.reserved_carts() reserved_carts = reserved_carts.exclude( user=user, status=commerce.Cart.STATUS_ACTIVE, ) - - items = self._items() - items = items.filter(cart__in=reserved_carts) - count = items.aggregate(Sum("quantity"))["quantity__sum"] or 0 - - return self.ceiling.limit - count + return reserved_carts class TimeOrStockLimitFlagController( TimeOrStockLimitConditionController): - def _items(self): - category_products = inventory.Product.objects.filter( - category__in=self.ceiling.categories.all(), - ) - products = self.ceiling.products.all() | category_products + @classmethod + def _calculate_quantities(cls, user): + reserved_carts = cls._relevant_carts(user) - product_items = commerce.ProductItem.objects.filter( - product__in=products.all(), + # Calculate category lines + cat_items = F('categories__product__productitem__product__category') + reserved_category_products = ( + Q(categories=cat_items) & + Q(categories__product__productitem__cart__in=reserved_carts) ) - return product_items + + # Calculate product lines + reserved_products = ( + Q(products=F('products__productitem__product')) & + Q(products__productitem__cart__in=reserved_carts) + ) + + category_quantity_in_reserved_carts = When( + reserved_category_products, + then="categories__product__productitem__quantity", + ) + + product_quantity_in_reserved_carts = When( + reserved_products, + then="products__productitem__quantity", + ) + + quantity_or_zero = Case( + category_quantity_in_reserved_carts, + product_quantity_in_reserved_carts, + default=Value(0), + ) + + return quantity_or_zero class TimeOrStockLimitDiscountController(TimeOrStockLimitConditionController): - def _items(self): - discount_items = commerce.DiscountItem.objects.filter( - discount=self.ceiling, + @classmethod + def _calculate_quantities(cls, user): + reserved_carts = cls._relevant_carts(user) + + quantity_in_reserved_carts = When( + discountitem__cart__in=reserved_carts, + then="discountitem__quantity" ) - return discount_items + + quantity_or_zero = Case( + quantity_in_reserved_carts, + default=Value(0) + ) + + return quantity_or_zero -class VoucherConditionController(ConditionController): +class VoucherConditionController(IsMetByFilter, ConditionController): ''' Condition test for VoucherFlag and VoucherDiscount.''' - def __init__(self, condition): - self.condition = condition + @classmethod + def pre_filter(self, queryset, user): + ''' Returns all of the items from queryset where the user has entered + a voucher that invokes that item's condition in one of their carts. ''' - def is_met(self, user): - ''' returns True if the user has the given voucher attached. ''' - carts_count = commerce.Cart.objects.filter( + carts = commerce.Cart.objects.filter( user=user, - vouchers=self.condition.voucher, - ).count() - return carts_count > 0 + ) + vouchers = [cart.vouchers.all() for cart in carts] + + return queryset.filter(voucher__in=itertools.chain(*vouchers)) diff --git a/registrasion/controllers/discount.py b/registrasion/controllers/discount.py index a8f9282e..21d8a51b 100644 --- a/registrasion/controllers/discount.py +++ b/registrasion/controllers/discount.py @@ -4,7 +4,11 @@ from conditions import ConditionController from registrasion.models import commerce from registrasion.models import conditions +from django.db.models import Case +from django.db.models import Q from django.db.models import Sum +from django.db.models import Value +from django.db.models import When class DiscountAndQuantity(object): @@ -43,6 +47,62 @@ def available_discounts(user, categories, products): and products. The discounts also list the available quantity for this user, not including products that are pending purchase. ''' + + + filtered_clauses = _filtered_discounts(user, categories, products) + + discounts = [] + + # Markers so that we don't need to evaluate given conditions more than once + accepted_discounts = set() + failed_discounts = set() + + for clause in filtered_clauses: + discount = clause.discount + cond = ConditionController.for_condition(discount) + + past_use_count = discount.past_use_count + + # TODO: add test case -- + # discount covers 2x prod_1 and 1x prod_2 + # add 1x prod_2 + # add 1x prod_1 + # checkout + # discount should be available for prod_1 + + if past_use_count >= clause.quantity: + # This clause has exceeded its use count + pass + elif discount not in failed_discounts: + # This clause is still available + is_accepted = discount in accepted_discounts + if is_accepted or cond.is_met(user, filtered=True): + # This clause is valid for this user + discounts.append(DiscountAndQuantity( + discount=discount, + clause=clause, + quantity=clause.quantity - past_use_count, + )) + accepted_discounts.add(discount) + else: + # This clause is not valid for this user + failed_discounts.add(discount) + return discounts + + +def _filtered_discounts(user, categories, products): + ''' + + Returns: + Sequence[discountbase]: All discounts that passed the filter function. + + ''' + + types = list(ConditionController._controllers()) + discounttypes = [ + i for i in types if issubclass(i, conditions.DiscountBase) + ] + # discounts that match provided categories category_discounts = conditions.DiscountForCategory.objects.filter( category__in=categories @@ -67,51 +127,56 @@ def available_discounts(user, categories, products): "category", ) + valid_discounts = conditions.DiscountBase.objects.filter( + Q(discountforproduct__in=product_discounts) | + Q(discountforcategory__in=all_category_discounts) + ) + + all_subsets = [] + + for discounttype in discounttypes: + discounts = discounttype.objects.filter(id__in=valid_discounts) + ctrl = ConditionController.for_type(discounttype) + discounts = ctrl.pre_filter(discounts, user) + discounts = _annotate_with_past_uses(discounts, user) + all_subsets.append(discounts) + + filtered_discounts = list(itertools.chain(*all_subsets)) + + # Map from discount key to itself (contains annotations added by filter) + from_filter = dict((i.id, i) for i in filtered_discounts) + # The set of all potential discounts - potential_discounts = set(itertools.chain( - product_discounts, - all_category_discounts, + discount_clauses = set(itertools.chain( + product_discounts.filter(discount__in=filtered_discounts), + all_category_discounts.filter(discount__in=filtered_discounts), )) - discounts = [] + # Replace discounts with the filtered ones + # These are the correct subclasses (saves query later on), and have + # correct annotations from filters if necessary. + for clause in discount_clauses: + clause.discount = from_filter[clause.discount.id] - # Markers so that we don't need to evaluate given conditions more than once - accepted_discounts = set() - failed_discounts = set() + return discount_clauses - for discount in potential_discounts: - real_discount = conditions.DiscountBase.objects.get_subclass( - pk=discount.discount.pk, - ) - cond = ConditionController.for_condition(real_discount) - # Count the past uses of the given discount item. - # If this user has exceeded the limit for the clause, this clause - # is not available any more. - past_uses = commerce.DiscountItem.objects.filter( - cart__user=user, - cart__status=commerce.Cart.STATUS_PAID, # Only past carts count - discount=real_discount, - ) - agg = past_uses.aggregate(Sum("quantity")) - past_use_count = agg["quantity__sum"] - if past_use_count is None: - past_use_count = 0 +def _annotate_with_past_uses(queryset, user): + ''' Annotates the queryset with a usage count for that discount by the + given user. ''' - if past_use_count >= discount.quantity: - # This clause has exceeded its use count - pass - elif real_discount not in failed_discounts: - # This clause is still available - if real_discount in accepted_discounts or cond.is_met(user): - # This clause is valid for this user - discounts.append(DiscountAndQuantity( - discount=real_discount, - clause=discount, - quantity=discount.quantity - past_use_count, - )) - accepted_discounts.add(real_discount) - else: - # This clause is not valid for this user - failed_discounts.add(real_discount) - return discounts + past_use_quantity = When( + ( + Q(discountitem__cart__user=user) & + Q(discountitem__cart__status=commerce.Cart.STATUS_PAID) + ), + then="discountitem__quantity", + ) + + past_use_quantity_or_zero = Case( + past_use_quantity, + default=Value(0), + ) + + queryset = queryset.annotate(past_use_count=Sum(past_use_quantity_or_zero)) + return queryset diff --git a/registrasion/tests/test_ceilings.py b/registrasion/tests/test_ceilings.py index 877556d0..1273a071 100644 --- a/registrasion/tests/test_ceilings.py +++ b/registrasion/tests/test_ceilings.py @@ -6,6 +6,8 @@ from django.core.exceptions import ValidationError from controller_helpers import TestingCartController from test_cart import RegistrationCartTestCase +from registrasion.controllers.discount import available_discounts +from registrasion.controllers.product import ProductController from registrasion.models import commerce from registrasion.models import conditions @@ -135,6 +137,39 @@ class CeilingsTestCases(RegistrationCartTestCase): with self.assertRaises(ValidationError): first_cart.validate_cart() + def test_discount_ceiling_aggregates_products(self): + # Create two carts, add 1xprod_1 to each. Ceiling should disappear + # after second. + self.make_discount_ceiling( + "Multi-product limit discount ceiling", + limit=2, + ) + for i in xrange(2): + cart = TestingCartController.for_user(self.USER_1) + cart.add_to_cart(self.PROD_1, 1) + cart.next_cart() + + discounts = available_discounts(self.USER_1, [], [self.PROD_1]) + + self.assertEqual(0, len(discounts)) + + def test_flag_ceiling_aggregates_products(self): + # Create two carts, add 1xprod_1 to each. Ceiling should disappear + # after second. + self.make_ceiling("Multi-product limit ceiling", limit=2) + + for i in xrange(2): + cart = TestingCartController.for_user(self.USER_1) + cart.add_to_cart(self.PROD_1, 1) + cart.next_cart() + + products = ProductController.available_products( + self.USER_1, + products=[self.PROD_1], + ) + + self.assertEqual(0, len(products)) + def test_items_released_from_ceiling_by_refund(self): self.make_ceiling("Limit ceiling", limit=1) diff --git a/setup.cfg b/setup.cfg index 290fdb45..9b05640c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,2 +1,2 @@ [flake8] -exclude = registrasion/migrations/*, build/*, docs/* +exclude = registrasion/migrations/*, build/*, docs/*, dist/* From 145fd057aca06e846f55d33c23c79dd0249a6569 Mon Sep 17 00:00:00 2001 From: Christopher Neugebauer Date: Thu, 28 Apr 2016 12:13:42 +1000 Subject: [PATCH 02/15] Breaks out flag-handling code into flag.py and FlagController --- registrasion/controllers/cart.py | 9 +- registrasion/controllers/conditions.py | 259 +------------------------ registrasion/controllers/flag.py | 257 ++++++++++++++++++++++++ registrasion/controllers/product.py | 7 +- 4 files changed, 268 insertions(+), 264 deletions(-) create mode 100644 registrasion/controllers/flag.py diff --git a/registrasion/controllers/cart.py b/registrasion/controllers/cart.py index e7282e9a..3eb4ce37 100644 --- a/registrasion/controllers/cart.py +++ b/registrasion/controllers/cart.py @@ -15,9 +15,10 @@ from registrasion.models import commerce from registrasion.models import conditions from registrasion.models import inventory -from category import CategoryController -from conditions import ConditionController -from product import ProductController +from .category import CategoryController +from .conditions import ConditionController +from .flag import FlagController +from .product import ProductController def _modifies_cart(func): @@ -185,7 +186,7 @@ class CartController(object): )) # Test the flag conditions - errs = ConditionController.test_flags( + errs = FlagController.test_flags( self.cart.user, product_quantities=product_quantities, ) diff --git a/registrasion/controllers/conditions.py b/registrasion/controllers/conditions.py index 291a573a..c0847b10 100644 --- a/registrasion/controllers/conditions.py +++ b/registrasion/controllers/conditions.py @@ -18,74 +18,7 @@ from registrasion.models import inventory -ConditionAndRemainder = namedtuple( - "ConditionAndRemainder", - ( - "condition", - "remainder", - ), -) - - -_FlagCounter = namedtuple( - "_FlagCounter", - ( - "products", - "categories", - ), -) - - -_ConditionsCount = namedtuple( - "ConditionsCount", - ( - "dif", - "eit", - ), -) - - -class FlagCounter(_FlagCounter): - - @classmethod - def count(cls): - # Get the count of how many conditions should exist per product - flagbases = conditions.FlagBase.objects - - types = ( - conditions.FlagBase.ENABLE_IF_TRUE, - conditions.FlagBase.DISABLE_IF_FALSE, - ) - keys = ("eit", "dif") - flags = [ - flagbases.filter( - condition=condition_type - ).values( - 'products', 'categories' - ).annotate( - count=Count('id') - ) - for condition_type in types - ] - - cats = defaultdict(lambda: defaultdict(int)) - prod = defaultdict(lambda: defaultdict(int)) - - for key, flagcounts in zip(keys, flags): - for row in flagcounts: - if row["products"] is not None: - prod[row["products"]][key] = row["count"] - if row["categories"] is not None: - cats[row["categories"]][key] = row["count"] - - return cls(products=prod, categories=cats) - - def get(self, product): - p = self.products[product.id] - c = self.categories[product.category.id] - eit = p["eit"] + c["eit"] - dif = p["dif"] + c["dif"] - return _ConditionsCount(dif=dif, eit=eit) +_BIG_QUANTITY = 99999999 # A big quantity class ConditionController(object): @@ -120,184 +53,6 @@ class ConditionController(object): except KeyError: return ConditionController() - SINGLE = True - PLURAL = False - NONE = True - SOME = False - MESSAGE = { - NONE: { - SINGLE: - "%(items)s is no longer available to you", - PLURAL: - "%(items)s are no longer available to you", - }, - SOME: { - SINGLE: - "Only %(remainder)d of the following item remains: %(items)s", - PLURAL: - "Only %(remainder)d of the following items remain: %(items)s" - }, - } - - @classmethod - def test_flags( - cls, user, products=None, product_quantities=None): - ''' Evaluates all of the flag conditions on the given products. - - If `product_quantities` is supplied, the condition is only met if it - will permit the sum of the product quantities for all of the products - it covers. Otherwise, it will be met if at least one item can be - accepted. - - If all flag conditions pass, an empty list is returned, otherwise - a list is returned containing all of the products that are *not - enabled*. ''' - - if products is not None and product_quantities is not None: - raise ValueError("Please specify only products or " - "product_quantities") - elif products is None: - products = set(i[0] for i in product_quantities) - quantities = dict((product, quantity) - for product, quantity in product_quantities) - elif product_quantities is None: - products = set(products) - quantities = {} - - if products: - # Simplify the query. - all_conditions = cls._filtered_flags(user, products) - else: - all_conditions = [] - - # All disable-if-false conditions on a product need to be met - do_not_disable = defaultdict(lambda: True) - # At least one enable-if-true condition on a product must be met - do_enable = defaultdict(lambda: False) - # (if either sort of condition is present) - - # Count the number of conditions for a product - dif_count = defaultdict(int) - eit_count = defaultdict(int) - - messages = {} - - for condition in all_conditions: - cond = cls.for_condition(condition) - remainder = cond.user_quantity_remaining(user, filtered=True) - - # Get all products covered by this condition, and the products - # from the categories covered by this condition - cond_products = condition.products.all() - from_category = inventory.Product.objects.filter( - category__in=condition.categories.all(), - ).all() - all_products = cond_products | from_category - all_products = all_products.select_related("category") - # Remove the products that we aren't asking about - all_products = [ - product - for product in all_products - if product in products - ] - - if quantities: - consumed = sum(quantities[i] for i in all_products) - else: - consumed = 1 - met = consumed <= remainder - - if not met: - items = ", ".join(str(product) for product in all_products) - base = cls.MESSAGE[remainder == 0][len(all_products) == 1] - message = base % {"items": items, "remainder": remainder} - - for product in all_products: - if condition.is_disable_if_false: - do_not_disable[product] &= met - dif_count[product] += 1 - else: - do_enable[product] |= met - eit_count[product] += 1 - - if not met and product not in messages: - messages[product] = message - - total_flags = FlagCounter.count() - - valid = {} - - # the problem is that now, not every condition falls into - # do_not_disable or do_enable ''' - # You should look into this, chris :) - - for product in products: - if quantities: - if quantities[product] == 0: - continue - - f = total_flags.get(product) - if f.dif > 0 and f.dif != dif_count[product]: - do_not_disable[product] = False - if product not in messages: - messages[product] = "Some disable-if-false " \ - "conditions were not met" - if f.eit > 0 and product not in do_enable: - do_enable[product] = False - if product not in messages: - messages[product] = "Some enable-if-true " \ - "conditions were not met" - - for product in itertools.chain(do_not_disable, do_enable): - f = total_flags.get(product) - if product in do_enable: - # If there's an enable-if-true, we need need of those met too. - # (do_not_disable will default to true otherwise) - valid[product] = do_not_disable[product] and do_enable[product] - elif product in do_not_disable: - # If there's a disable-if-false condition, all must be met - valid[product] = do_not_disable[product] - - error_fields = [ - (product, messages[product]) - for product in valid if not valid[product] - ] - - return error_fields - - @classmethod - def _filtered_flags(cls, user, products): - ''' - - Returns: - Sequence[flagbase]: All flags that passed the filter function. - - ''' - - types = list(ConditionController._controllers()) - flagtypes = [i for i in types if issubclass(i, conditions.FlagBase)] - - # Get all flags for the products and categories. - prods = ( - product.flagbase_set.all() - for product in products - ) - cats = ( - category.flagbase_set.all() - for category in set(product.category for product in products) - ) - all_flags = reduce(operator.or_, itertools.chain(prods, cats)) - - all_subsets = [] - - for flagtype in flagtypes: - flags = flagtype.objects.filter(id__in=all_flags) - ctrl = ConditionController.for_type(flagtype) - flags = ctrl.pre_filter(flags, user) - all_subsets.append(flags) - - return itertools.chain(*all_subsets) - @classmethod def pre_filter(cls, queryset, user): ''' Returns only the flag conditions that might be available for this @@ -369,16 +124,6 @@ class IsMetByFilter(object): return self.passes_filter(user) - carts = commerce.Cart.objects.filter(user=user) - carts = carts.exclude(status=commerce.Cart.STATUS_RELEASED) - enabling_products = inventory.Product.objects.filter( - category=self.condition.enabling_category, - ) - products_count = commerce.ProductItem.objects.filter( - cart__in=carts, - product__in=enabling_products, - ).count() - return products_count > 0 class RemainderSetByFilter(object): @@ -491,7 +236,7 @@ class TimeOrStockLimitFlagController( # Calculate category lines cat_items = F('categories__product__productitem__product__category') reserved_category_products = ( - Q(categories=cat_items) & + Q(categories=F('categories__product__productitem__product__category')) & Q(categories__product__productitem__cart__in=reserved_carts) ) diff --git a/registrasion/controllers/flag.py b/registrasion/controllers/flag.py new file mode 100644 index 00000000..40901931 --- /dev/null +++ b/registrasion/controllers/flag.py @@ -0,0 +1,257 @@ +import itertools +import operator + +from collections import defaultdict +from collections import namedtuple +from django.db.models import Count + +from .conditions import ConditionController + +from registrasion.models import conditions +from registrasion.models import inventory + + +class FlagController(object): + + SINGLE = True + PLURAL = False + NONE = True + SOME = False + MESSAGE = { + NONE: { + SINGLE: + "%(items)s is no longer available to you", + PLURAL: + "%(items)s are no longer available to you", + }, + SOME: { + SINGLE: + "Only %(remainder)d of the following item remains: %(items)s", + PLURAL: + "Only %(remainder)d of the following items remain: %(items)s" + }, + } + + @classmethod + def test_flags( + cls, user, products=None, product_quantities=None): + ''' Evaluates all of the flag conditions on the given products. + + If `product_quantities` is supplied, the condition is only met if it + will permit the sum of the product quantities for all of the products + it covers. Otherwise, it will be met if at least one item can be + accepted. + + If all flag conditions pass, an empty list is returned, otherwise + a list is returned containing all of the products that are *not + enabled*. ''' + + print "GREPME: test_flags()" + + if products is not None and product_quantities is not None: + raise ValueError("Please specify only products or " + "product_quantities") + elif products is None: + products = set(i[0] for i in product_quantities) + quantities = dict((product, quantity) + for product, quantity in product_quantities) + elif product_quantities is None: + products = set(products) + quantities = {} + + if products: + # Simplify the query. + all_conditions = cls._filtered_flags(user, products) + else: + all_conditions = [] + + # All disable-if-false conditions on a product need to be met + do_not_disable = defaultdict(lambda: True) + # At least one enable-if-true condition on a product must be met + do_enable = defaultdict(lambda: False) + # (if either sort of condition is present) + + # Count the number of conditions for a product + dif_count = defaultdict(int) + eit_count = defaultdict(int) + + messages = {} + + for condition in all_conditions: + cond = ConditionController.for_condition(condition) + remainder = cond.user_quantity_remaining(user, filtered=True) + + # Get all products covered by this condition, and the products + # from the categories covered by this condition + cond_products = condition.products.all() + from_category = inventory.Product.objects.filter( + category__in=condition.categories.all(), + ).all() + all_products = cond_products | from_category + all_products = all_products.select_related("category") + # Remove the products that we aren't asking about + all_products = [ + product + for product in all_products + if product in products + ] + + if quantities: + consumed = sum(quantities[i] for i in all_products) + else: + consumed = 1 + met = consumed <= remainder + + if not met: + items = ", ".join(str(product) for product in all_products) + base = cls.MESSAGE[remainder == 0][len(all_products) == 1] + message = base % {"items": items, "remainder": remainder} + + for product in all_products: + if condition.is_disable_if_false: + do_not_disable[product] &= met + dif_count[product] += 1 + else: + do_enable[product] |= met + eit_count[product] += 1 + + if not met and product not in messages: + messages[product] = message + + total_flags = FlagCounter.count() + + valid = {} + + # the problem is that now, not every condition falls into + # do_not_disable or do_enable ''' + # You should look into this, chris :) + + for product in products: + if quantities: + if quantities[product] == 0: + continue + + f = total_flags.get(product) + if f.dif > 0 and f.dif != dif_count[product]: + do_not_disable[product] = False + if product not in messages: + messages[product] = "Some disable-if-false " \ + "conditions were not met" + if f.eit > 0 and product not in do_enable: + do_enable[product] = False + if product not in messages: + messages[product] = "Some enable-if-true " \ + "conditions were not met" + + for product in itertools.chain(do_not_disable, do_enable): + f = total_flags.get(product) + if product in do_enable: + # If there's an enable-if-true, we need need of those met too. + # (do_not_disable will default to true otherwise) + valid[product] = do_not_disable[product] and do_enable[product] + elif product in do_not_disable: + # If there's a disable-if-false condition, all must be met + valid[product] = do_not_disable[product] + + error_fields = [ + (product, messages[product]) + for product in valid if not valid[product] + ] + + return error_fields + + @classmethod + def _filtered_flags(cls, user, products): + ''' + + Returns: + Sequence[flagbase]: All flags that passed the filter function. + + ''' + + types = list(ConditionController._controllers()) + flagtypes = [i for i in types if issubclass(i, conditions.FlagBase)] + + # Get all flags for the products and categories. + prods = ( + product.flagbase_set.all() + for product in products + ) + cats = ( + category.flagbase_set.all() + for category in set(product.category for product in products) + ) + all_flags = reduce(operator.or_, itertools.chain(prods, cats)) + + all_subsets = [] + + for flagtype in flagtypes: + flags = flagtype.objects.filter(id__in=all_flags) + ctrl = ConditionController.for_type(flagtype) + flags = ctrl.pre_filter(flags, user) + all_subsets.append(flags) + + return itertools.chain(*all_subsets) + + +ConditionAndRemainder = namedtuple( + "ConditionAndRemainder", + ( + "condition", + "remainder", + ), +) + + +_FlagCounter = namedtuple( + "_FlagCounter", + ( + "products", + "categories", + ), +) + + +_ConditionsCount = namedtuple( + "ConditionsCount", + ( + "dif", + "eit", + ), +) + + +class FlagCounter(_FlagCounter): + + @classmethod + def count(cls): + # Get the count of how many conditions should exist per product + flagbases = conditions.FlagBase.objects + + types = (conditions.FlagBase.ENABLE_IF_TRUE, conditions.FlagBase.DISABLE_IF_FALSE) + keys = ("eit", "dif") + flags = [ + flagbases.filter(condition=condition_type + ).values('products', 'categories' + ).annotate(count=Count('id')) + for condition_type in types + ] + + cats = defaultdict(lambda: defaultdict(int)) + prod = defaultdict(lambda: defaultdict(int)) + + for key, flagcounts in zip(keys, flags): + for row in flagcounts: + if row["products"] is not None: + prod[row["products"]][key] = row["count"] + if row["categories"] is not None: + cats[row["categories"]][key] = row["count"] + + return cls(products=prod, categories=cats) + + def get(self, product): + p = self.products[product.id] + c = self.categories[product.category.id] + eit = p["eit"] + c["eit"] + dif = p["dif"] + c["dif"] + return _ConditionsCount(dif=dif, eit=eit) diff --git a/registrasion/controllers/product.py b/registrasion/controllers/product.py index 99618ded..09f66bb3 100644 --- a/registrasion/controllers/product.py +++ b/registrasion/controllers/product.py @@ -4,8 +4,9 @@ from django.db.models import Sum from registrasion.models import commerce from registrasion.models import inventory -from category import CategoryController -from conditions import ConditionController +from .category import CategoryController +from .conditions import ConditionController +from .flag import FlagController class ProductController(object): @@ -46,7 +47,7 @@ class ProductController(object): if cls(product).user_quantity_remaining(user) > 0 ) - failed_and_messages = ConditionController.test_flags( + failed_and_messages = FlagController.test_flags( user, products=passed_limits ) failed_conditions = set(i[0] for i in failed_and_messages) From 71de0df5dc5116d3154aedf507323e7e73c6a797 Mon Sep 17 00:00:00 2001 From: Christopher Neugebauer Date: Thu, 28 Apr 2016 12:20:36 +1000 Subject: [PATCH 03/15] Makes DiscountController a class and puts available_discounts inside it --- registrasion/controllers/cart.py | 3 +- registrasion/controllers/discount.py | 215 +++++++++++++-------------- registrasion/tests/test_ceilings.py | 4 +- registrasion/tests/test_discount.py | 36 ++--- registrasion/views.py | 4 +- 5 files changed, 128 insertions(+), 134 deletions(-) diff --git a/registrasion/controllers/cart.py b/registrasion/controllers/cart.py index 3eb4ce37..597dd47b 100644 --- a/registrasion/controllers/cart.py +++ b/registrasion/controllers/cart.py @@ -17,6 +17,7 @@ from registrasion.models import inventory from .category import CategoryController from .conditions import ConditionController +from .discount import DiscountController from .flag import FlagController from .product import ProductController @@ -377,7 +378,7 @@ class CartController(object): ) products = [i.product for i in product_items] - discounts = discount.available_discounts(self.cart.user, [], products) + discounts = DiscountController.available_discounts(self.cart.user, [], products) # The highest-value discounts will apply to the highest-value # products first. diff --git a/registrasion/controllers/discount.py b/registrasion/controllers/discount.py index 21d8a51b..c04ed42d 100644 --- a/registrasion/controllers/discount.py +++ b/registrasion/controllers/discount.py @@ -42,141 +42,134 @@ class DiscountAndQuantity(object): ) -def available_discounts(user, categories, products): - ''' Returns all discounts available to this user for the given categories - and products. The discounts also list the available quantity for this user, - not including products that are pending purchase. ''' +class DiscountController(object): + + @classmethod + def available_discounts(cls, user, categories, products): + ''' Returns all discounts available to this user for the given categories + and products. The discounts also list the available quantity for this user, + not including products that are pending purchase. ''' + filtered_clauses = cls._filtered_discounts(user, categories, products) - filtered_clauses = _filtered_discounts(user, categories, products) + discounts = [] - discounts = [] + # Markers so that we don't need to evaluate given conditions more than once + accepted_discounts = set() + failed_discounts = set() - # Markers so that we don't need to evaluate given conditions more than once - accepted_discounts = set() - failed_discounts = set() + for clause in filtered_clauses: + discount = clause.discount + cond = ConditionController.for_condition(discount) - for clause in filtered_clauses: - discount = clause.discount - cond = ConditionController.for_condition(discount) - - past_use_count = discount.past_use_count - - # TODO: add test case -- - # discount covers 2x prod_1 and 1x prod_2 - # add 1x prod_2 - # add 1x prod_1 - # checkout - # discount should be available for prod_1 - - if past_use_count >= clause.quantity: - # This clause has exceeded its use count - pass - elif discount not in failed_discounts: - # This clause is still available - is_accepted = discount in accepted_discounts - if is_accepted or cond.is_met(user, filtered=True): - # This clause is valid for this user - discounts.append(DiscountAndQuantity( - discount=discount, - clause=clause, - quantity=clause.quantity - past_use_count, - )) - accepted_discounts.add(discount) - else: - # This clause is not valid for this user - failed_discounts.add(discount) - return discounts + past_use_count = discount.past_use_count -def _filtered_discounts(user, categories, products): - ''' + if past_use_count >= clause.quantity: + # This clause has exceeded its use count + pass + elif discount not in failed_discounts: + # This clause is still available + if discount in accepted_discounts or cond.is_met(user, filtered=True): + # This clause is valid for this user + discounts.append(DiscountAndQuantity( + discount=discount, + clause=clause, + quantity=clause.quantity - past_use_count, + )) + accepted_discounts.add(discount) + else: + # This clause is not valid for this user + failed_discounts.add(discount) + return discounts - Returns: - Sequence[discountbase]: All discounts that passed the filter function. + @classmethod + def _filtered_discounts(cls, user, categories, products): + ''' - ''' + Returns: + Sequence[discountbase]: All discounts that passed the filter function. - types = list(ConditionController._controllers()) - discounttypes = [ - i for i in types if issubclass(i, conditions.DiscountBase) - ] + ''' - # discounts that match provided categories - category_discounts = conditions.DiscountForCategory.objects.filter( - category__in=categories - ) - # discounts that match provided products - product_discounts = conditions.DiscountForProduct.objects.filter( - product__in=products - ) - # discounts that match categories for provided products - product_category_discounts = conditions.DiscountForCategory.objects.filter( - category__in=(product.category for product in products) - ) - # (Not relevant: discounts that match products in provided categories) + types = list(ConditionController._controllers()) + discounttypes = [i for i in types if issubclass(i, conditions.DiscountBase)] - product_discounts = product_discounts.select_related( - "product", - "product__category", - ) + # discounts that match provided categories + category_discounts = conditions.DiscountForCategory.objects.filter( + category__in=categories + ) + # discounts that match provided products + product_discounts = conditions.DiscountForProduct.objects.filter( + product__in=products + ) + # discounts that match categories for provided products + product_category_discounts = conditions.DiscountForCategory.objects.filter( + category__in=(product.category for product in products) + ) + # (Not relevant: discounts that match products in provided categories) - all_category_discounts = category_discounts | product_category_discounts - all_category_discounts = all_category_discounts.select_related( - "category", - ) + product_discounts = product_discounts.select_related( + "product", + "product__category", + ) - valid_discounts = conditions.DiscountBase.objects.filter( - Q(discountforproduct__in=product_discounts) | - Q(discountforcategory__in=all_category_discounts) - ) + all_category_discounts = category_discounts | product_category_discounts + all_category_discounts = all_category_discounts.select_related( + "category", + ) - all_subsets = [] + valid_discounts = conditions.DiscountBase.objects.filter( + Q(discountforproduct__in=product_discounts) | + Q(discountforcategory__in=all_category_discounts) + ) - for discounttype in discounttypes: - discounts = discounttype.objects.filter(id__in=valid_discounts) - ctrl = ConditionController.for_type(discounttype) - discounts = ctrl.pre_filter(discounts, user) - discounts = _annotate_with_past_uses(discounts, user) - all_subsets.append(discounts) + all_subsets = [] - filtered_discounts = list(itertools.chain(*all_subsets)) + for discounttype in discounttypes: + discounts = discounttype.objects.filter(id__in=valid_discounts) + ctrl = ConditionController.for_type(discounttype) + discounts = ctrl.pre_filter(discounts, user) + discounts = cls._annotate_with_past_uses(discounts, user) + all_subsets.append(discounts) - # Map from discount key to itself (contains annotations added by filter) - from_filter = dict((i.id, i) for i in filtered_discounts) + filtered_discounts = list(itertools.chain(*all_subsets)) - # The set of all potential discounts - discount_clauses = set(itertools.chain( - product_discounts.filter(discount__in=filtered_discounts), - all_category_discounts.filter(discount__in=filtered_discounts), - )) + # Map from discount key to itself (contains annotations added by filter) + from_filter = dict((i.id, i) for i in filtered_discounts) - # Replace discounts with the filtered ones - # These are the correct subclasses (saves query later on), and have - # correct annotations from filters if necessary. - for clause in discount_clauses: - clause.discount = from_filter[clause.discount.id] + # The set of all potential discounts + discount_clauses = set(itertools.chain( + product_discounts.filter(discount__in=filtered_discounts), + all_category_discounts.filter(discount__in=filtered_discounts), + )) - return discount_clauses + # Replace discounts with the filtered ones + # These are the correct subclasses (saves query later on), and have + # correct annotations from filters if necessary. + for clause in discount_clauses: + clause.discount = from_filter[clause.discount.id] + return discount_clauses -def _annotate_with_past_uses(queryset, user): - ''' Annotates the queryset with a usage count for that discount by the - given user. ''' + @classmethod + def _annotate_with_past_uses(cls, queryset, user): + ''' Annotates the queryset with a usage count for that discount by the + given user. ''' - past_use_quantity = When( - ( - Q(discountitem__cart__user=user) & - Q(discountitem__cart__status=commerce.Cart.STATUS_PAID) - ), - then="discountitem__quantity", - ) + past_use_quantity = When( + ( + Q(discountitem__cart__user=user) & + Q(discountitem__cart__status=commerce.Cart.STATUS_PAID) + ), + then="discountitem__quantity", + ) - past_use_quantity_or_zero = Case( - past_use_quantity, - default=Value(0), - ) + past_use_quantity_or_zero = Case( + past_use_quantity, + default=Value(0), + ) - queryset = queryset.annotate(past_use_count=Sum(past_use_quantity_or_zero)) - return queryset + queryset = queryset.annotate(past_use_count=Sum(past_use_quantity_or_zero)) + return queryset diff --git a/registrasion/tests/test_ceilings.py b/registrasion/tests/test_ceilings.py index 1273a071..bd28b598 100644 --- a/registrasion/tests/test_ceilings.py +++ b/registrasion/tests/test_ceilings.py @@ -6,7 +6,7 @@ from django.core.exceptions import ValidationError from controller_helpers import TestingCartController from test_cart import RegistrationCartTestCase -from registrasion.controllers.discount import available_discounts +from registrasion.controllers.discount import DiscountController from registrasion.controllers.product import ProductController from registrasion.models import commerce from registrasion.models import conditions @@ -149,7 +149,7 @@ class CeilingsTestCases(RegistrationCartTestCase): cart.add_to_cart(self.PROD_1, 1) cart.next_cart() - discounts = available_discounts(self.USER_1, [], [self.PROD_1]) + discounts = DiscountController.available_discounts(self.USER_1, [], [self.PROD_1]) self.assertEqual(0, len(discounts)) diff --git a/registrasion/tests/test_discount.py b/registrasion/tests/test_discount.py index 71d09d97..fd8302de 100644 --- a/registrasion/tests/test_discount.py +++ b/registrasion/tests/test_discount.py @@ -4,7 +4,7 @@ from decimal import Decimal from registrasion.models import commerce from registrasion.models import conditions -from registrasion.controllers import discount +from registrasion.controllers.discount import DiscountController from controller_helpers import TestingCartController from test_cart import RegistrationCartTestCase @@ -243,22 +243,22 @@ class DiscountTestCase(RegistrationCartTestCase): # The discount is applied. self.assertEqual(1, len(discount_items)) - # Tests for the discount.available_discounts enumerator + # Tests for the DiscountController.available_discounts enumerator def test_enumerate_no_discounts_for_no_input(self): - discounts = discount.available_discounts(self.USER_1, [], []) + discounts = DiscountController.available_discounts(self.USER_1, [], []) self.assertEqual(0, len(discounts)) def test_enumerate_no_discounts_if_condition_not_met(self): self.add_discount_prod_1_includes_cat_2(quantity=1) - discounts = discount.available_discounts( + discounts = DiscountController.available_discounts( self.USER_1, [], [self.PROD_3], ) self.assertEqual(0, len(discounts)) - discounts = discount.available_discounts(self.USER_1, [self.CAT_2], []) + discounts = DiscountController.available_discounts(self.USER_1, [self.CAT_2], []) self.assertEqual(0, len(discounts)) def test_category_discount_appears_once_if_met_twice(self): @@ -267,7 +267,7 @@ class DiscountTestCase(RegistrationCartTestCase): cart = TestingCartController.for_user(self.USER_1) cart.add_to_cart(self.PROD_1, 1) # Enable the discount - discounts = discount.available_discounts( + discounts = DiscountController.available_discounts( self.USER_1, [self.CAT_2], [self.PROD_3], @@ -280,7 +280,7 @@ class DiscountTestCase(RegistrationCartTestCase): cart = TestingCartController.for_user(self.USER_1) cart.add_to_cart(self.PROD_1, 1) # Enable the discount - discounts = discount.available_discounts(self.USER_1, [self.CAT_2], []) + discounts = DiscountController.available_discounts(self.USER_1, [self.CAT_2], []) self.assertEqual(1, len(discounts)) def test_category_discount_appears_with_product(self): @@ -289,7 +289,7 @@ class DiscountTestCase(RegistrationCartTestCase): cart = TestingCartController.for_user(self.USER_1) cart.add_to_cart(self.PROD_1, 1) # Enable the discount - discounts = discount.available_discounts( + discounts = DiscountController.available_discounts( self.USER_1, [], [self.PROD_3], @@ -302,7 +302,7 @@ class DiscountTestCase(RegistrationCartTestCase): cart = TestingCartController.for_user(self.USER_1) cart.add_to_cart(self.PROD_1, 1) # Enable the discount - discounts = discount.available_discounts( + discounts = DiscountController.available_discounts( self.USER_1, [], [self.PROD_3, self.PROD_4] @@ -315,7 +315,7 @@ class DiscountTestCase(RegistrationCartTestCase): cart = TestingCartController.for_user(self.USER_1) cart.add_to_cart(self.PROD_1, 1) # Enable the discount - discounts = discount.available_discounts( + discounts = DiscountController.available_discounts( self.USER_1, [], [self.PROD_2], @@ -328,7 +328,7 @@ class DiscountTestCase(RegistrationCartTestCase): cart = TestingCartController.for_user(self.USER_1) cart.add_to_cart(self.PROD_1, 1) # Enable the discount - discounts = discount.available_discounts(self.USER_1, [self.CAT_1], []) + discounts = DiscountController.available_discounts(self.USER_1, [self.CAT_1], []) self.assertEqual(0, len(discounts)) def test_discount_quantity_is_correct_before_first_purchase(self): @@ -338,7 +338,7 @@ class DiscountTestCase(RegistrationCartTestCase): cart.add_to_cart(self.PROD_1, 1) # Enable the discount cart.add_to_cart(self.PROD_3, 1) # Exhaust the quantity - discounts = discount.available_discounts(self.USER_1, [self.CAT_2], []) + discounts = DiscountController.available_discounts(self.USER_1, [self.CAT_2], []) self.assertEqual(2, discounts[0].quantity) cart.next_cart() @@ -349,21 +349,21 @@ class DiscountTestCase(RegistrationCartTestCase): cart = TestingCartController.for_user(self.USER_1) cart.add_to_cart(self.PROD_3, 1) # Exhaust the quantity - discounts = discount.available_discounts(self.USER_1, [self.CAT_2], []) + discounts = DiscountController.available_discounts(self.USER_1, [self.CAT_2], []) self.assertEqual(1, discounts[0].quantity) cart.next_cart() def test_discount_is_gone_after_quantity_exhausted(self): self.test_discount_quantity_is_correct_after_first_purchase() - discounts = discount.available_discounts(self.USER_1, [self.CAT_2], []) + discounts = DiscountController.available_discounts(self.USER_1, [self.CAT_2], []) self.assertEqual(0, len(discounts)) def test_product_discount_enabled_twice_appears_twice(self): self.add_discount_prod_1_includes_prod_3_and_prod_4(quantity=2) cart = TestingCartController.for_user(self.USER_1) cart.add_to_cart(self.PROD_1, 1) # Enable the discount - discounts = discount.available_discounts( + discounts = DiscountController.available_discounts( self.USER_1, [], [self.PROD_3, self.PROD_4], @@ -374,7 +374,7 @@ class DiscountTestCase(RegistrationCartTestCase): self.add_discount_prod_1_includes_prod_2(quantity=2) cart = TestingCartController.for_user(self.USER_1) cart.add_to_cart(self.PROD_1, 1) # Enable the discount - discounts = discount.available_discounts( + discounts = DiscountController.available_discounts( self.USER_1, [], [self.PROD_2], @@ -388,7 +388,7 @@ class DiscountTestCase(RegistrationCartTestCase): cart.next_cart() - discounts = discount.available_discounts( + discounts = DiscountController.available_discounts( self.USER_1, [], [self.PROD_2], @@ -398,7 +398,7 @@ class DiscountTestCase(RegistrationCartTestCase): cart.cart.status = commerce.Cart.STATUS_RELEASED cart.cart.save() - discounts = discount.available_discounts( + discounts = DiscountController.available_discounts( self.USER_1, [], [self.PROD_2], diff --git a/registrasion/views.py b/registrasion/views.py index f10de90f..bdfcf551 100644 --- a/registrasion/views.py +++ b/registrasion/views.py @@ -5,7 +5,7 @@ from registrasion import util from registrasion.models import commerce from registrasion.models import inventory from registrasion.models import people -from registrasion.controllers import discount +from registrasion.controllers.discount import DiscountController from registrasion.controllers.cart import CartController from registrasion.controllers.credit_note import CreditNoteController from registrasion.controllers.invoice import InvoiceController @@ -427,7 +427,7 @@ def _handle_products(request, category, products, prefix): ) handled = False if products_form.errors else True - discounts = discount.available_discounts(request.user, [], products) + discounts = DiscountController.available_discounts(request.user, [], products) return products_form, discounts, handled From 162db248178173c53d5ed3dec185eac884cb4204 Mon Sep 17 00:00:00 2001 From: Christopher Neugebauer Date: Thu, 28 Apr 2016 12:39:20 +1000 Subject: [PATCH 04/15] Flake8 fixes --- registrasion/controllers/cart.py | 7 +++-- registrasion/controllers/conditions.py | 9 ++---- registrasion/controllers/discount.py | 33 +++++++++++++------- registrasion/controllers/flag.py | 15 ++++++--- registrasion/controllers/product.py | 1 - registrasion/tests/test_cart.py | 2 +- registrasion/tests/test_ceilings.py | 6 +++- registrasion/tests/test_discount.py | 42 +++++++++++++++++++++----- registrasion/views.py | 6 +++- 9 files changed, 86 insertions(+), 35 deletions(-) diff --git a/registrasion/controllers/cart.py b/registrasion/controllers/cart.py index 597dd47b..47340c15 100644 --- a/registrasion/controllers/cart.py +++ b/registrasion/controllers/cart.py @@ -1,6 +1,5 @@ import collections import datetime -import discount import functools import itertools @@ -378,7 +377,11 @@ class CartController(object): ) products = [i.product for i in product_items] - discounts = DiscountController.available_discounts(self.cart.user, [], products) + discounts = DiscountController.available_discounts( + self.cart.user, + [], + products, + ) # The highest-value discounts will apply to the highest-value # products first. diff --git a/registrasion/controllers/conditions.py b/registrasion/controllers/conditions.py index c0847b10..0a2b3c4a 100644 --- a/registrasion/controllers/conditions.py +++ b/registrasion/controllers/conditions.py @@ -1,11 +1,6 @@ import itertools -import operator - -from collections import defaultdict -from collections import namedtuple from django.db.models import Case -from django.db.models import Count from django.db.models import F, Q from django.db.models import Sum from django.db.models import Value @@ -234,9 +229,9 @@ class TimeOrStockLimitFlagController( reserved_carts = cls._relevant_carts(user) # Calculate category lines - cat_items = F('categories__product__productitem__product__category') + item_cats = F('categories__product__productitem__product__category') reserved_category_products = ( - Q(categories=F('categories__product__productitem__product__category')) & + Q(categories=item_cats) & Q(categories__product__productitem__cart__in=reserved_carts) ) diff --git a/registrasion/controllers/discount.py b/registrasion/controllers/discount.py index c04ed42d..1c7fa59f 100644 --- a/registrasion/controllers/discount.py +++ b/registrasion/controllers/discount.py @@ -46,16 +46,17 @@ class DiscountController(object): @classmethod def available_discounts(cls, user, categories, products): - ''' Returns all discounts available to this user for the given categories - and products. The discounts also list the available quantity for this user, - not including products that are pending purchase. ''' + ''' Returns all discounts available to this user for the given + categories and products. The discounts also list the available quantity + for this user, not including products that are pending purchase. ''' filtered_clauses = cls._filtered_discounts(user, categories, products) discounts = [] - # Markers so that we don't need to evaluate given conditions more than once + # Markers so that we don't need to evaluate given conditions + # more than once accepted_discounts = set() failed_discounts = set() @@ -71,7 +72,8 @@ class DiscountController(object): pass elif discount not in failed_discounts: # This clause is still available - if discount in accepted_discounts or cond.is_met(user, filtered=True): + is_accepted = discount in accepted_discounts + if is_accepted or cond.is_met(user, filtered=True): # This clause is valid for this user discounts.append(DiscountAndQuantity( discount=discount, @@ -89,12 +91,15 @@ class DiscountController(object): ''' Returns: - Sequence[discountbase]: All discounts that passed the filter function. + Sequence[discountbase]: All discounts that passed the filter + function. ''' types = list(ConditionController._controllers()) - discounttypes = [i for i in types if issubclass(i, conditions.DiscountBase)] + discounttypes = [ + i for i in types if issubclass(i, conditions.DiscountBase) + ] # discounts that match provided categories category_discounts = conditions.DiscountForCategory.objects.filter( @@ -105,7 +110,8 @@ class DiscountController(object): product__in=products ) # discounts that match categories for provided products - product_category_discounts = conditions.DiscountForCategory.objects.filter( + product_category_discounts = conditions.DiscountForCategory.objects + product_category_discounts = product_category_discounts.filter( category__in=(product.category for product in products) ) # (Not relevant: discounts that match products in provided categories) @@ -115,7 +121,9 @@ class DiscountController(object): "product__category", ) - all_category_discounts = category_discounts | product_category_discounts + all_category_discounts = ( + category_discounts | product_category_discounts + ) all_category_discounts = all_category_discounts.select_related( "category", ) @@ -136,7 +144,8 @@ class DiscountController(object): filtered_discounts = list(itertools.chain(*all_subsets)) - # Map from discount key to itself (contains annotations added by filter) + # Map from discount key to itself + # (contains annotations needed in the future) from_filter = dict((i.id, i) for i in filtered_discounts) # The set of all potential discounts @@ -171,5 +180,7 @@ class DiscountController(object): default=Value(0), ) - queryset = queryset.annotate(past_use_count=Sum(past_use_quantity_or_zero)) + queryset = queryset.annotate( + past_use_count=Sum(past_use_quantity_or_zero) + ) return queryset diff --git a/registrasion/controllers/flag.py b/registrasion/controllers/flag.py index 40901931..a67a0b94 100644 --- a/registrasion/controllers/flag.py +++ b/registrasion/controllers/flag.py @@ -228,12 +228,19 @@ class FlagCounter(_FlagCounter): # Get the count of how many conditions should exist per product flagbases = conditions.FlagBase.objects - types = (conditions.FlagBase.ENABLE_IF_TRUE, conditions.FlagBase.DISABLE_IF_FALSE) + types = ( + conditions.FlagBase.ENABLE_IF_TRUE, + conditions.FlagBase.DISABLE_IF_FALSE, + ) keys = ("eit", "dif") flags = [ - flagbases.filter(condition=condition_type - ).values('products', 'categories' - ).annotate(count=Count('id')) + flagbases.filter( + condition=condition_type + ).values( + 'products', 'categories' + ).annotate( + count=Count('id') + ) for condition_type in types ] diff --git a/registrasion/controllers/product.py b/registrasion/controllers/product.py index 09f66bb3..e1ad9ef7 100644 --- a/registrasion/controllers/product.py +++ b/registrasion/controllers/product.py @@ -5,7 +5,6 @@ from registrasion.models import commerce from registrasion.models import inventory from .category import CategoryController -from .conditions import ConditionController from .flag import FlagController diff --git a/registrasion/tests/test_cart.py b/registrasion/tests/test_cart.py index 507d5cf7..790c1df9 100644 --- a/registrasion/tests/test_cart.py +++ b/registrasion/tests/test_cart.py @@ -26,7 +26,7 @@ class RegistrationCartTestCase(SetTimeMixin, TestCase): super(RegistrationCartTestCase, self).setUp() def tearDown(self): - if False: + if True: # If you're seeing segfaults in tests, enable this. call_command( 'flush', diff --git a/registrasion/tests/test_ceilings.py b/registrasion/tests/test_ceilings.py index bd28b598..87b54946 100644 --- a/registrasion/tests/test_ceilings.py +++ b/registrasion/tests/test_ceilings.py @@ -149,7 +149,11 @@ class CeilingsTestCases(RegistrationCartTestCase): cart.add_to_cart(self.PROD_1, 1) cart.next_cart() - discounts = DiscountController.available_discounts(self.USER_1, [], [self.PROD_1]) + discounts = DiscountController.available_discounts( + self.USER_1, + [], + [self.PROD_1], + ) self.assertEqual(0, len(discounts)) diff --git a/registrasion/tests/test_discount.py b/registrasion/tests/test_discount.py index fd8302de..4b92c81b 100644 --- a/registrasion/tests/test_discount.py +++ b/registrasion/tests/test_discount.py @@ -245,7 +245,11 @@ class DiscountTestCase(RegistrationCartTestCase): # Tests for the DiscountController.available_discounts enumerator def test_enumerate_no_discounts_for_no_input(self): - discounts = DiscountController.available_discounts(self.USER_1, [], []) + discounts = DiscountController.available_discounts( + self.USER_1, + [], + [], + ) self.assertEqual(0, len(discounts)) def test_enumerate_no_discounts_if_condition_not_met(self): @@ -258,7 +262,11 @@ class DiscountTestCase(RegistrationCartTestCase): ) self.assertEqual(0, len(discounts)) - discounts = DiscountController.available_discounts(self.USER_1, [self.CAT_2], []) + discounts = DiscountController.available_discounts( + self.USER_1, + [self.CAT_2], + [], + ) self.assertEqual(0, len(discounts)) def test_category_discount_appears_once_if_met_twice(self): @@ -280,7 +288,11 @@ class DiscountTestCase(RegistrationCartTestCase): cart = TestingCartController.for_user(self.USER_1) cart.add_to_cart(self.PROD_1, 1) # Enable the discount - discounts = DiscountController.available_discounts(self.USER_1, [self.CAT_2], []) + discounts = DiscountController.available_discounts( + self.USER_1, + [self.CAT_2], + [], + ) self.assertEqual(1, len(discounts)) def test_category_discount_appears_with_product(self): @@ -328,7 +340,11 @@ class DiscountTestCase(RegistrationCartTestCase): cart = TestingCartController.for_user(self.USER_1) cart.add_to_cart(self.PROD_1, 1) # Enable the discount - discounts = DiscountController.available_discounts(self.USER_1, [self.CAT_1], []) + discounts = DiscountController.available_discounts( + self.USER_1, + [self.CAT_1], + [], + ) self.assertEqual(0, len(discounts)) def test_discount_quantity_is_correct_before_first_purchase(self): @@ -338,7 +354,11 @@ class DiscountTestCase(RegistrationCartTestCase): cart.add_to_cart(self.PROD_1, 1) # Enable the discount cart.add_to_cart(self.PROD_3, 1) # Exhaust the quantity - discounts = DiscountController.available_discounts(self.USER_1, [self.CAT_2], []) + discounts = DiscountController.available_discounts( + self.USER_1, + [self.CAT_2], + [], + ) self.assertEqual(2, discounts[0].quantity) cart.next_cart() @@ -349,14 +369,22 @@ class DiscountTestCase(RegistrationCartTestCase): cart = TestingCartController.for_user(self.USER_1) cart.add_to_cart(self.PROD_3, 1) # Exhaust the quantity - discounts = DiscountController.available_discounts(self.USER_1, [self.CAT_2], []) + discounts = DiscountController.available_discounts( + self.USER_1, + [self.CAT_2], + [], + ) self.assertEqual(1, discounts[0].quantity) cart.next_cart() def test_discount_is_gone_after_quantity_exhausted(self): self.test_discount_quantity_is_correct_after_first_purchase() - discounts = DiscountController.available_discounts(self.USER_1, [self.CAT_2], []) + discounts = DiscountController.available_discounts( + self.USER_1, + [self.CAT_2], + [], + ) self.assertEqual(0, len(discounts)) def test_product_discount_enabled_twice_appears_twice(self): diff --git a/registrasion/views.py b/registrasion/views.py index bdfcf551..416afe0a 100644 --- a/registrasion/views.py +++ b/registrasion/views.py @@ -427,7 +427,11 @@ def _handle_products(request, category, products, prefix): ) handled = False if products_form.errors else True - discounts = DiscountController.available_discounts(request.user, [], products) + discounts = DiscountController.available_discounts( + request.user, + [], + products, + ) return products_form, discounts, handled From 587e6e20b284e6f2522bf89900c16ea6682e3a67 Mon Sep 17 00:00:00 2001 From: Christopher Neugebauer Date: Thu, 28 Apr 2016 14:01:36 +1000 Subject: [PATCH 05/15] Adds an operations_batch context manager that allows batches of modifying operations to be nested. Closes #44. --- registrasion/controllers/cart.py | 87 +++++++++++++++++++++++------ registrasion/controllers/invoice.py | 6 ++ 2 files changed, 75 insertions(+), 18 deletions(-) diff --git a/registrasion/controllers/cart.py b/registrasion/controllers/cart.py index 47340c15..48cec0dd 100644 --- a/registrasion/controllers/cart.py +++ b/registrasion/controllers/cart.py @@ -1,4 +1,5 @@ import collections +import contextlib import datetime import functools import itertools @@ -23,12 +24,19 @@ from .product import ProductController def _modifies_cart(func): ''' Decorator that makes the wrapped function raise ValidationError - if we're doing something that could modify the cart. ''' + if we're doing something that could modify the cart. + + It also wraps the execution of this function in a database transaction, + and marks the boundaries of a cart operations batch. + ''' @functools.wraps(func) def inner(self, *a, **k): self._fail_if_cart_is_not_active() - return func(self, *a, **k) + with transaction.atomic(): + with CartController.operations_batch(self.cart.user) as mark: + mark.mark = True # Marker that we've modified the cart + return func(self, *a, **k) return inner @@ -56,13 +64,57 @@ class CartController(object): ) return cls(existing) + + # Marks the carts that are currently in batches + _BATCH_COUNT = collections.defaultdict(int) + _MODIFIED_CARTS = set() + + class _ModificationMarker(object): + pass + + @classmethod + @contextlib.contextmanager + def operations_batch(cls, user): + ''' Marks the boundary for a batch of operations on a user's cart. + + These markers can be nested. Only on exiting the outermost marker will + a batch be ended. + + When a batch is ended, discounts are recalculated, and the cart's + revision is increased. + ''' + + ctrl = cls.for_user(user) + _id = ctrl.cart.id + + cls._BATCH_COUNT[_id] += 1 + try: + success = False + + marker = cls._ModificationMarker() + yield marker + + if hasattr(marker, "mark"): + cls._MODIFIED_CARTS.add(_id) + + success = True + finally: + + cls._BATCH_COUNT[_id] -= 1 + + # Only end on the outermost batch marker, and only if + # it excited cleanly, and a modification occurred + modified = _id in cls._MODIFIED_CARTS + if modified and cls._BATCH_COUNT[_id] == 0 and success: + ctrl._end_batch() + cls._MODIFIED_CARTS.remove(_id) + def _fail_if_cart_is_not_active(self): self.cart.refresh_from_db() if self.cart.status != commerce.Cart.STATUS_ACTIVE: raise ValidationError("You can only amend active carts.") - @_modifies_cart - def extend_reservation(self): + def _autoextend_reservation(self): ''' Updates the cart's time last updated value, which is used to determine whether the cart has reserved the items and discounts it holds. ''' @@ -84,21 +136,26 @@ class CartController(object): self.cart.time_last_updated = timezone.now() self.cart.reservation_duration = max(reservations) - @_modifies_cart - def end_batch(self): + def _end_batch(self): ''' Performs operations that occur occur at the end of a batch of product changes/voucher applications etc. - THIS SHOULD BE PRIVATE + + You need to call this after you've finished modifying the user's cart. + This is normally done by wrapping a block of code using + ``operations_batch``. + ''' - self.recalculate_discounts() - self.extend_reservation() + self.cart.refresh_from_db() + + self._recalculate_discounts() + + self._autoextend_reservation() self.cart.revision += 1 self.cart.save() @_modifies_cart - @transaction.atomic def set_quantities(self, product_quantities): ''' Sets the quantities on each of the products on each of the products specified. Raises an exception (ValidationError) if a limit @@ -140,8 +197,6 @@ class CartController(object): items_in_cart.filter(quantity=0).delete() - self.end_batch() - def _test_limits(self, product_quantities): ''' Tests that the quantity changes we intend to make do not violate the limits and flag conditions imposed on the products. ''' @@ -213,7 +268,6 @@ class CartController(object): # If successful... self.cart.vouchers.add(voucher) - self.end_batch() def _test_voucher(self, voucher): ''' Tests whether this voucher is allowed to be applied to this cart. @@ -331,7 +385,6 @@ class CartController(object): raise ValidationError(errors) @_modifies_cart - @transaction.atomic def fix_simple_errors(self): ''' This attempts to fix the easy errors raised by ValidationError. This includes removing items from the cart that are no longer @@ -363,11 +416,9 @@ class CartController(object): self.set_quantities(zeros) - @_modifies_cart @transaction.atomic - def recalculate_discounts(self): - ''' Calculates all of the discounts available for this product. - ''' + def _recalculate_discounts(self): + ''' Calculates all of the discounts available for this product.''' # Delete the existing entries. commerce.DiscountItem.objects.filter(cart=self.cart).delete() diff --git a/registrasion/controllers/invoice.py b/registrasion/controllers/invoice.py index 62bab0b6..9ef155ef 100644 --- a/registrasion/controllers/invoice.py +++ b/registrasion/controllers/invoice.py @@ -29,6 +29,7 @@ class InvoiceController(ForId, object): If such an invoice does not exist, the cart is validated, and if valid, an invoice is generated.''' + cart.refresh_from_db() try: invoice = commerce.Invoice.objects.exclude( status=commerce.Invoice.STATUS_VOID, @@ -74,6 +75,8 @@ class InvoiceController(ForId, object): def _generate(cls, cart): ''' Generates an invoice for the given cart. ''' + cart.refresh_from_db() + issued = timezone.now() reservation_limit = cart.reservation_duration + cart.time_last_updated # Never generate a due time that is before the issue time @@ -251,6 +254,9 @@ class InvoiceController(ForId, object): def _invoice_matches_cart(self): ''' Returns true if there is no cart, or if the revision of this invoice matches the current revision of the cart. ''' + + self._refresh() + cart = self.invoice.cart if not cart: return True From 76e6206d09876645e12588e1569f95954c0631d6 Mon Sep 17 00:00:00 2001 From: Christopher Neugebauer Date: Thu, 28 Apr 2016 14:46:17 +1000 Subject: [PATCH 06/15] Wraps the guided registration handler in views.py in a batch marker --- registrasion/views.py | 48 ++++++++++++++++++++++--------------------- 1 file changed, 25 insertions(+), 23 deletions(-) diff --git a/registrasion/views.py b/registrasion/views.py index 416afe0a..e23b57a1 100644 --- a/registrasion/views.py +++ b/registrasion/views.py @@ -181,33 +181,35 @@ def guided_registration(request): attendee.save() return next_step - for category in cats: - products = [ - i for i in available_products - if i.category == category - ] + with CartController.operations_batch(request.user): + for category in cats: + products = [ + i for i in available_products + if i.category == category + ] - prefix = "category_" + str(category.id) - p = _handle_products(request, category, products, prefix) - products_form, discounts, products_handled = p + prefix = "category_" + str(category.id) + p = _handle_products(request, category, products, prefix) + products_form, discounts, products_handled = p - section = GuidedRegistrationSection( - title=category.name, - description=category.description, - discounts=discounts, - form=products_form, - ) + section = GuidedRegistrationSection( + title=category.name, + description=category.description, + discounts=discounts, + form=products_form, + ) - if products: - # This product category has items to show. - sections.append(section) - # Add this to the list of things to show if the form errors. - request.session[SESSION_KEY].append(category.id) + if products: + # This product category has items to show. + sections.append(section) + # Add this to the list of things to show if the form + # errors. + request.session[SESSION_KEY].append(category.id) - if request.method == "POST" and not products_form.errors: - # This is only saved if we pass each form with no errors, - # and if the form actually has products. - attendee.guided_categories_complete.add(category) + if request.method == "POST" and not products_form.errors: + # This is only saved if we pass each form with no + # errors, and if the form actually has products. + attendee.guided_categories_complete.add(category) if sections and request.method == "POST": for section in sections: From 3b5b958b78b63f2494d073aa1cc670a176cc55fd Mon Sep 17 00:00:00 2001 From: Christopher Neugebauer Date: Thu, 28 Apr 2016 17:19:27 +1000 Subject: [PATCH 07/15] =?UTF-8?q?Makes=20the=20discounts=20section=20from?= =?UTF-8?q?=20=5Fhandle=5Fproducts=20evaluate=20lazily,=20just=20in=20case?= =?UTF-8?q?=20it=E2=80=99s=20never=20displayed=20in=20a=20template=20(thos?= =?UTF-8?q?e=20are=20some=20very=20very=20expensive=20queries=20there).?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- registrasion/util.py | 30 ++++++++++++++++++++++++++++++ registrasion/views.py | 6 +++++- 2 files changed, 35 insertions(+), 1 deletion(-) diff --git a/registrasion/util.py b/registrasion/util.py index 7179ceb5..54f56a1e 100644 --- a/registrasion/util.py +++ b/registrasion/util.py @@ -25,3 +25,33 @@ def all_arguments_optional(ntcls): ) return ntcls + + +def lazy(function, *args, **kwargs): + ''' Produces a callable so that functions can be lazily evaluated in + templates. + + Arguments: + + function (callable): The function to call at evaluation time. + + args: Positional arguments, passed directly to ``function``. + + kwargs: Keyword arguments, passed directly to ``function``. + + Return: + + callable: A callable that will evaluate a call to ``function`` with + the specified arguments. + + ''' + + NOT_EVALUATED = object() + retval = [NOT_EVALUATED] + + def evaluate(): + if retval[0] is NOT_EVALUATED: + retval[0] = function(*args, **kwargs) + return retval[0] + + return evaluate diff --git a/registrasion/views.py b/registrasion/views.py index e23b57a1..41c4a0d6 100644 --- a/registrasion/views.py +++ b/registrasion/views.py @@ -429,7 +429,11 @@ def _handle_products(request, category, products, prefix): ) handled = False if products_form.errors else True - discounts = DiscountController.available_discounts( + # Making this a function to lazily evaluate when it's displayed + # in templates. + + discounts = util.lazy( + DiscountController.available_discounts, request.user, [], products, From a79ad3520e60af3415f234c32a0f3120b9ea31a5 Mon Sep 17 00:00:00 2001 From: Christopher Neugebauer Date: Thu, 28 Apr 2016 18:57:55 +1000 Subject: [PATCH 08/15] Puts attach_remainders on ProductController and CategoryController, eliminating the need to query each product and category separately. --- registrasion/controllers/cart.py | 17 ++++-- registrasion/controllers/category.py | 65 ++++++++++++++------- registrasion/controllers/product.py | 84 ++++++++++++++++++---------- 3 files changed, 113 insertions(+), 53 deletions(-) diff --git a/registrasion/controllers/cart.py b/registrasion/controllers/cart.py index 48cec0dd..30b36178 100644 --- a/registrasion/controllers/cart.py +++ b/registrasion/controllers/cart.py @@ -203,13 +203,17 @@ class CartController(object): errors = [] + # Pre-annotate products + products = [p for (p, q) in product_quantities] + r = ProductController.attach_user_remainders(self.cart.user, products) + with_remainders = dict((p, p) for p in r) + # Test each product limit here for product, quantity in product_quantities: if quantity < 0: errors.append((product, "Value must be zero or greater.")) - prod = ProductController(product) - limit = prod.user_quantity_remaining(self.cart.user) + limit = with_remainders[product].remainder if quantity > limit: errors.append(( @@ -224,10 +228,15 @@ class CartController(object): for product, quantity in product_quantities: by_cat[product.category].append((product, quantity)) + # Pre-annotate categories + r = CategoryController.attach_user_remainders(self.cart.user, by_cat) + with_remainders = dict((cat, cat) for cat in r) + # Test each category limit here for category in by_cat: - ctrl = CategoryController(category) - limit = ctrl.user_quantity_remaining(self.cart.user) + #ctrl = CategoryController(category) + #limit = ctrl.user_quantity_remaining(self.cart.user) + limit = with_remainders[category].remainder # Get the amount so far in the cart to_add = sum(i[1] for i in by_cat[category]) diff --git a/registrasion/controllers/category.py b/registrasion/controllers/category.py index a986eea7..c3f38ed9 100644 --- a/registrasion/controllers/category.py +++ b/registrasion/controllers/category.py @@ -1,7 +1,11 @@ from registrasion.models import commerce from registrasion.models import inventory +from django.db.models import Case +from django.db.models import F, Q from django.db.models import Sum +from django.db.models import When +from django.db.models import Value class AllProducts(object): @@ -34,25 +38,48 @@ class CategoryController(object): return set(i.category for i in available) + + @classmethod + def attach_user_remainders(cls, user, categories): + ''' + + Return: + queryset(inventory.Product): A queryset containing items from + ``categories``, with an extra attribute -- remainder = the amount + of items from this category that is remaining. + ''' + + ids = [category.id for category in categories] + categories = inventory.Category.objects.filter(id__in=ids) + + cart_filter = ( + Q(product__productitem__cart__user=user) & + Q(product__productitem__cart__status=commerce.Cart.STATUS_PAID) + ) + + quantity = When( + cart_filter, + then='product__productitem__quantity' + ) + + quantity_or_zero = Case( + quantity, + default=Value(0), + ) + + remainder = Case( + When(limit_per_user=None, then=Value(99999999)), + default=F('limit_per_user') - Sum(quantity_or_zero), + ) + + categories = categories.annotate(remainder=remainder) + + return categories + def user_quantity_remaining(self, user): - ''' Returns the number of items from this category that the user may - add in the current cart. ''' + ''' Returns the quantity of this product that the user add in the + current cart. ''' - cat_limit = self.category.limit_per_user + with_remainders = self.attach_user_remainders(user, [self.category]) - if cat_limit is None: - # We don't need to waste the following queries - return 99999999 - - carts = commerce.Cart.objects.filter( - user=user, - status=commerce.Cart.STATUS_PAID, - ) - - items = commerce.ProductItem.objects.filter( - cart__in=carts, - product__category=self.category, - ) - - cat_count = items.aggregate(Sum("quantity"))["quantity__sum"] or 0 - return cat_limit - cat_count + return with_remainders[0].remainder diff --git a/registrasion/controllers/product.py b/registrasion/controllers/product.py index e1ad9ef7..74783002 100644 --- a/registrasion/controllers/product.py +++ b/registrasion/controllers/product.py @@ -1,6 +1,11 @@ import itertools +from django.db.models import Case +from django.db.models import F, Q from django.db.models import Sum +from django.db.models import When +from django.db.models import Value + from registrasion.models import commerce from registrasion.models import inventory @@ -16,9 +21,7 @@ class ProductController(object): @classmethod def available_products(cls, user, category=None, products=None): ''' Returns a list of all of the products that are available per - flag conditions from the given categories. - TODO: refactor so that all conditions are tested here and - can_add_with_flags calls this method. ''' + flag conditions from the given categories. ''' if category is None and products is None: raise ValueError("You must provide products or a category") @@ -31,19 +34,18 @@ class ProductController(object): if products is not None: all_products = set(itertools.chain(all_products, products)) - cat_quants = dict( - ( - category, - CategoryController(category).user_quantity_remaining(user), - ) - for category in set(product.category for product in all_products) - ) + categories = set(product.category for product in all_products) + r = CategoryController.attach_user_remainders(user, categories) + cat_quants = dict((c,c) for c in r) + + r = ProductController.attach_user_remainders(user, all_products) + prod_quants = dict((p,p) for p in r) passed_limits = set( product for product in all_products - if cat_quants[product.category] > 0 - if cls(product).user_quantity_remaining(user) > 0 + if cat_quants[product.category].remainder > 0 + if prod_quants[product].remainder > 0 ) failed_and_messages = FlagController.test_flags( @@ -56,26 +58,48 @@ class ProductController(object): return out + + @classmethod + def attach_user_remainders(cls, user, products): + ''' + + Return: + queryset(inventory.Product): A queryset containing items from + ``product``, with an extra attribute -- remainder = the amount of + this item that is remaining. + ''' + + ids = [product.id for product in products] + products = inventory.Product.objects.filter(id__in=ids) + + cart_filter = ( + Q(productitem__cart__user=user) & + Q(productitem__cart__status=commerce.Cart.STATUS_PAID) + ) + + quantity = When( + cart_filter, + then='productitem__quantity' + ) + + quantity_or_zero = Case( + quantity, + default=Value(0), + ) + + remainder = Case( + When(limit_per_user=None, then=Value(99999999)), + default=F('limit_per_user') - Sum(quantity_or_zero), + ) + + products = products.annotate(remainder=remainder) + + return products + def user_quantity_remaining(self, user): ''' Returns the quantity of this product that the user add in the current cart. ''' - prod_limit = self.product.limit_per_user + with_remainders = self.attach_user_remainders(user, [self.product]) - if prod_limit is None: - # Don't need to run the remaining queries - return 999999 # We can do better - - carts = commerce.Cart.objects.filter( - user=user, - status=commerce.Cart.STATUS_PAID, - ) - - items = commerce.ProductItem.objects.filter( - cart__in=carts, - product=self.product, - ) - - prod_count = items.aggregate(Sum("quantity"))["quantity__sum"] or 0 - - return prod_limit - prod_count + return with_remainders[0].remainder From fd5cf50fabd84b3dcfa9b926e66dee6fa09fef7c Mon Sep 17 00:00:00 2001 From: Christopher Neugebauer Date: Thu, 28 Apr 2016 19:52:39 +1000 Subject: [PATCH 09/15] Makes items_purchased do more database work --- .../templatetags/registrasion_tags.py | 37 ++++++++++++++----- 1 file changed, 27 insertions(+), 10 deletions(-) diff --git a/registrasion/templatetags/registrasion_tags.py b/registrasion/templatetags/registrasion_tags.py index fabc7754..6100d589 100644 --- a/registrasion/templatetags/registrasion_tags.py +++ b/registrasion/templatetags/registrasion_tags.py @@ -4,7 +4,11 @@ from registrasion.controllers.category import CategoryController from collections import namedtuple from django import template +from django.db.models import Case +from django.db.models import Q from django.db.models import Sum +from django.db.models import When +from django.db.models import Value register = template.Library() @@ -99,20 +103,33 @@ def items_purchased(context, category=None): ''' - all_items = commerce.ProductItem.objects.filter( - cart__user=context.request.user, - cart__status=commerce.Cart.STATUS_PAID, - ).select_related("product", "product__category") + in_cart=( + Q(productitem__cart__user=context.request.user) & + Q(productitem__cart__status=commerce.Cart.STATUS_PAID) + ) + + quantities_in_cart = When( + in_cart, + then="productitem__quantity", + ) + + quantities_or_zero = Case( + quantities_in_cart, + default=Value(0), + ) + + products = inventory.Product.objects if category: - all_items = all_items.filter(product__category=category) + products = products.filter(category=category) + + products = products.select_related("category") + products = products.annotate(quantity=Sum(quantities_or_zero)) + products = products.filter(quantity__gt=0) - pq = all_items.values("product").annotate(quantity=Sum("quantity")).all() - products = inventory.Product.objects.all() out = [] - for item in pq: - prod = products.get(pk=item["product"]) - out.append(ProductAndQuantity(prod, item["quantity"])) + for prod in products: + out.append(ProductAndQuantity(prod, prod.quantity)) return out From 4fb569d9353dffaeb067cf314ef55b3561e97388 Mon Sep 17 00:00:00 2001 From: Christopher Neugebauer Date: Thu, 28 Apr 2016 19:58:09 +1000 Subject: [PATCH 10/15] Does more select_related and bulk_create calls --- registrasion/controllers/cart.py | 1 + registrasion/controllers/invoice.py | 22 +++++++++++++++++++--- 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/registrasion/controllers/cart.py b/registrasion/controllers/cart.py index 30b36178..7e1e9072 100644 --- a/registrasion/controllers/cart.py +++ b/registrasion/controllers/cart.py @@ -358,6 +358,7 @@ class CartController(object): errors.append(ve) items = commerce.ProductItem.objects.filter(cart=cart) + items = items.select_related("product", "product__category") product_quantities = list((i.product, i.quantity) for i in items) try: diff --git a/registrasion/controllers/invoice.py b/registrasion/controllers/invoice.py index 9ef155ef..da737a73 100644 --- a/registrasion/controllers/invoice.py +++ b/registrasion/controllers/invoice.py @@ -99,6 +99,10 @@ class InvoiceController(ForId, object): ) product_items = commerce.ProductItem.objects.filter(cart=cart) + product_items = product_items.select_related( + "product", + "product__category", + ) if len(product_items) == 0: raise ValidationError("Your cart is empty.") @@ -106,29 +110,41 @@ class InvoiceController(ForId, object): product_items = product_items.order_by( "product__category__order", "product__order" ) + discount_items = commerce.DiscountItem.objects.filter(cart=cart) + discount_items = discount_items.select_related( + "discount", + "product", + "product__category", + ) + + line_items = [] + invoice_value = Decimal() for item in product_items: product = item.product - line_item = commerce.LineItem.objects.create( + line_item = commerce.LineItem( invoice=invoice, description="%s - %s" % (product.category.name, product.name), quantity=item.quantity, price=product.price, product=product, ) + line_items.append(line_item) invoice_value += line_item.quantity * line_item.price - for item in discount_items: - line_item = commerce.LineItem.objects.create( + line_item = commerce.LineItem( invoice=invoice, description=item.discount.description, quantity=item.quantity, price=cls.resolve_discount_value(item) * -1, product=item.product, ) + line_items.append(line_item) invoice_value += line_item.quantity * line_item.price + commerce.LineItem.objects.bulk_create(line_items) + invoice.value = invoice_value invoice.save() From 6d52a4c18ff85785b6729073158eedd128158c61 Mon Sep 17 00:00:00 2001 From: Christopher Neugebauer Date: Thu, 28 Apr 2016 20:15:21 +1000 Subject: [PATCH 11/15] More low-hanging query optimisations --- registrasion/controllers/cart.py | 37 ++++++++++++++++---------- registrasion/controllers/conditions.py | 35 +++++++++++++----------- registrasion/controllers/flag.py | 22 +++++++-------- registrasion/views.py | 7 +++-- 4 files changed, 58 insertions(+), 43 deletions(-) diff --git a/registrasion/controllers/cart.py b/registrasion/controllers/cart.py index 7e1e9072..83e08ef8 100644 --- a/registrasion/controllers/cart.py +++ b/registrasion/controllers/cart.py @@ -8,6 +8,7 @@ from django.core.exceptions import ObjectDoesNotExist from django.core.exceptions import ValidationError from django.db import transaction from django.db.models import Max +from django.db.models import Q from django.utils import timezone from registrasion.exceptions import CartValidationError @@ -84,6 +85,8 @@ class CartController(object): revision is increased. ''' + # TODO cache carts mid-batch? + ctrl = cls.for_user(user) _id = ctrl.cart.id @@ -180,22 +183,28 @@ class CartController(object): # Validate that the limits we're adding are OK self._test_limits(all_product_quantities) + new_items = [] + products = [] for product, quantity in product_quantities: - try: - product_item = commerce.ProductItem.objects.get( - cart=self.cart, - product=product, - ) - product_item.quantity = quantity - product_item.save() - except ObjectDoesNotExist: - commerce.ProductItem.objects.create( - cart=self.cart, - product=product, - quantity=quantity, - ) + products.append(product) - items_in_cart.filter(quantity=0).delete() + if quantity == 0: + continue + + item = commerce.ProductItem( + cart=self.cart, + product=product, + quantity=quantity, + ) + new_items.append(item) + + to_delete = ( + Q(quantity=0) | + Q(product__in=products) + ) + + items_in_cart.filter(to_delete).delete() + commerce.ProductItem.objects.bulk_create(new_items) def _test_limits(self, product_quantities): ''' Tests that the quantity changes we intend to make do not violate diff --git a/registrasion/controllers/conditions.py b/registrasion/controllers/conditions.py index 0a2b3c4a..2480a4fc 100644 --- a/registrasion/controllers/conditions.py +++ b/registrasion/controllers/conditions.py @@ -154,12 +154,17 @@ class CategoryConditionController(IsMetByFilter, ConditionController): product from a category invoking that item's condition in one of their carts. ''' - items = commerce.ProductItem.objects.filter(cart__user=user) - items = items.exclude(cart__status=commerce.Cart.STATUS_RELEASED) - items = items.select_related("product", "product__category") - categories = [item.product.category for item in items] + in_user_carts = Q( + enabling_category__product__productitem__cart__user=user + ) + released = commerce.Cart.STATUS_RELEASED + in_released_carts = Q( + enabling_category__product__productitem__cart__status=released + ) + queryset = queryset.filter(in_user_carts) + queryset = queryset.exclude(in_released_carts) - return queryset.filter(enabling_category__in=categories) + return queryset class ProductConditionController(IsMetByFilter, ConditionController): @@ -171,12 +176,15 @@ class ProductConditionController(IsMetByFilter, ConditionController): ''' Returns all of the items from queryset where the user has a product invoking that item's condition in one of their carts. ''' - items = commerce.ProductItem.objects.filter(cart__user=user) - items = items.exclude(cart__status=commerce.Cart.STATUS_RELEASED) - items = items.select_related("product", "product__category") - products = [item.product for item in items] + in_user_carts = Q(enabling_products__productitem__cart__user=user) + released = commerce.Cart.STATUS_RELEASED + in_released_carts = Q( + enabling_products__productitem__cart__status=released + ) + queryset = queryset.filter(in_user_carts) + queryset = queryset.exclude(in_released_carts) - return queryset.filter(enabling_products__in=products) + return queryset class TimeOrStockLimitConditionController( @@ -287,9 +295,4 @@ class VoucherConditionController(IsMetByFilter, ConditionController): ''' Returns all of the items from queryset where the user has entered a voucher that invokes that item's condition in one of their carts. ''' - carts = commerce.Cart.objects.filter( - user=user, - ) - vouchers = [cart.vouchers.all() for cart in carts] - - return queryset.filter(voucher__in=itertools.chain(*vouchers)) + return queryset.filter(voucher__cart__user=user) diff --git a/registrasion/controllers/flag.py b/registrasion/controllers/flag.py index a67a0b94..aa11d53e 100644 --- a/registrasion/controllers/flag.py +++ b/registrasion/controllers/flag.py @@ -4,6 +4,7 @@ import operator from collections import defaultdict from collections import namedtuple from django.db.models import Count +from django.db.models import Q from .conditions import ConditionController @@ -83,18 +84,16 @@ class FlagController(object): # Get all products covered by this condition, and the products # from the categories covered by this condition - cond_products = condition.products.all() - from_category = inventory.Product.objects.filter( - category__in=condition.categories.all(), - ).all() - all_products = cond_products | from_category + + ids = [product.id for product in products] + all_products = inventory.Product.objects.filter(id__in=ids) + cond = ( + Q(flagbase_set=condition) | + Q(category__in=condition.categories.all()) + ) + + all_products = all_products.filter(cond) all_products = all_products.select_related("category") - # Remove the products that we aren't asking about - all_products = [ - product - for product in all_products - if product in products - ] if quantities: consumed = sum(quantities[i] for i in all_products) @@ -221,6 +220,7 @@ _ConditionsCount = namedtuple( ) +# TODO: this should be cacheable. class FlagCounter(_FlagCounter): @classmethod diff --git a/registrasion/views.py b/registrasion/views.py index 41c4a0d6..a7917406 100644 --- a/registrasion/views.py +++ b/registrasion/views.py @@ -445,14 +445,17 @@ def _handle_products(request, category, products, prefix): def _set_quantities_from_products_form(products_form, current_cart): quantities = list(products_form.product_quantities()) - + id_to_quantity = dict(i[:2] for i in quantities) pks = [i[0] for i in quantities] products = inventory.Product.objects.filter( id__in=pks, ).select_related("category") + + + # TODO: This is fundamentally dumb product_quantities = [ - (products.get(pk=i[0]), i[1]) for i in quantities + (product, id_to_quantity[product.id]) for product in products ] field_names = dict( (i[0][0], i[1][2]) for i in zip(product_quantities, quantities) From 02fe88a4e4db19943791a7a07a20890685d1bf5e Mon Sep 17 00:00:00 2001 From: Christopher Neugebauer Date: Fri, 29 Apr 2016 10:08:21 +1000 Subject: [PATCH 12/15] Tests and fixes for a bug where discount quantities did not respect per-line item quantities. --- registrasion/controllers/discount.py | 44 ++++++++++++++++++---------- registrasion/tests/test_discount.py | 23 +++++++++++++++ 2 files changed, 52 insertions(+), 15 deletions(-) diff --git a/registrasion/controllers/discount.py b/registrasion/controllers/discount.py index 1c7fa59f..164d95cc 100644 --- a/registrasion/controllers/discount.py +++ b/registrasion/controllers/discount.py @@ -5,7 +5,7 @@ from registrasion.models import commerce from registrasion.models import conditions from django.db.models import Case -from django.db.models import Q +from django.db.models import F, Q from django.db.models import Sum from django.db.models import Value from django.db.models import When @@ -64,9 +64,7 @@ class DiscountController(object): discount = clause.discount cond = ConditionController.for_condition(discount) - past_use_count = discount.past_use_count - - + past_use_count = clause.past_use_count if past_use_count >= clause.quantity: # This clause has exceeded its use count pass @@ -139,7 +137,6 @@ class DiscountController(object): discounts = discounttype.objects.filter(id__in=valid_discounts) ctrl = ConditionController.for_type(discounttype) discounts = ctrl.pre_filter(discounts, user) - discounts = cls._annotate_with_past_uses(discounts, user) all_subsets.append(discounts) filtered_discounts = list(itertools.chain(*all_subsets)) @@ -148,11 +145,17 @@ class DiscountController(object): # (contains annotations needed in the future) from_filter = dict((i.id, i) for i in filtered_discounts) - # The set of all potential discounts - discount_clauses = set(itertools.chain( + clause_sets = ( product_discounts.filter(discount__in=filtered_discounts), all_category_discounts.filter(discount__in=filtered_discounts), - )) + ) + + clause_sets = ( + cls._annotate_with_past_uses(i, user) for i in clause_sets + ) + + # The set of all potential discount clauses + discount_clauses = set(itertools.chain(*clause_sets)) # Replace discounts with the filtered ones # These are the correct subclasses (saves query later on), and have @@ -164,15 +167,26 @@ class DiscountController(object): @classmethod def _annotate_with_past_uses(cls, queryset, user): - ''' Annotates the queryset with a usage count for that discount by the - given user. ''' + ''' Annotates the queryset with a usage count for that discount claus + by the given user. ''' + + if queryset.model == conditions.DiscountForCategory: + matches = ( + Q(category=F('discount__discountitem__product__category')) + ) + elif queryset.model == conditions.DiscountForProduct: + matches = ( + Q(product=F('discount__discountitem__product')) + ) + + in_carts = ( + Q(discount__discountitem__cart__user=user) & + Q(discount__discountitem__cart__status=commerce.Cart.STATUS_PAID) + ) past_use_quantity = When( - ( - Q(discountitem__cart__user=user) & - Q(discountitem__cart__status=commerce.Cart.STATUS_PAID) - ), - then="discountitem__quantity", + in_carts & matches, + then="discount__discountitem__quantity", ) past_use_quantity_or_zero = Case( diff --git a/registrasion/tests/test_discount.py b/registrasion/tests/test_discount.py index 4b92c81b..d7920a10 100644 --- a/registrasion/tests/test_discount.py +++ b/registrasion/tests/test_discount.py @@ -398,6 +398,29 @@ class DiscountTestCase(RegistrationCartTestCase): ) self.assertEqual(2, len(discounts)) + def test_product_discount_applied_on_different_invoices(self): + # quantity=1 means "quantity per product" + self.add_discount_prod_1_includes_prod_3_and_prod_4(quantity=1) + cart = TestingCartController.for_user(self.USER_1) + cart.add_to_cart(self.PROD_1, 1) # Enable the discount + discounts = DiscountController.available_discounts( + self.USER_1, + [], + [self.PROD_3, self.PROD_4], + ) + self.assertEqual(2, len(discounts)) + # adding one of PROD_3 should make it no longer an available discount. + cart.add_to_cart(self.PROD_3, 1) + cart.next_cart() + + # should still have (and only have) the discount for prod_4 + discounts = DiscountController.available_discounts( + self.USER_1, + [], + [self.PROD_3, self.PROD_4], + ) + self.assertEqual(1, len(discounts)) + def test_discounts_are_released_by_refunds(self): self.add_discount_prod_1_includes_prod_2(quantity=2) cart = TestingCartController.for_user(self.USER_1) From 4eff8194f96f59d311f40280c2e9bdcd38731d22 Mon Sep 17 00:00:00 2001 From: Christopher Neugebauer Date: Fri, 29 Apr 2016 10:48:51 +1000 Subject: [PATCH 13/15] Reduces CartController re-loading when batching operations --- registrasion/controllers/cart.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/registrasion/controllers/cart.py b/registrasion/controllers/cart.py index 83e08ef8..b445a802 100644 --- a/registrasion/controllers/cart.py +++ b/registrasion/controllers/cart.py @@ -67,6 +67,7 @@ class CartController(object): # Marks the carts that are currently in batches + _FOR_USER = {} _BATCH_COUNT = collections.defaultdict(int) _MODIFIED_CARTS = set() @@ -85,10 +86,11 @@ class CartController(object): revision is increased. ''' - # TODO cache carts mid-batch? + if user not in cls._FOR_USER: + _ctrl = cls.for_user(user) + cls._FOR_USER[user] = (_ctrl, _ctrl.cart.id) - ctrl = cls.for_user(user) - _id = ctrl.cart.id + ctrl, _id = cls._FOR_USER[user] cls._BATCH_COUNT[_id] += 1 try: @@ -108,10 +110,15 @@ class CartController(object): # Only end on the outermost batch marker, and only if # it excited cleanly, and a modification occurred modified = _id in cls._MODIFIED_CARTS - if modified and cls._BATCH_COUNT[_id] == 0 and success: + outermost = cls._BATCH_COUNT[_id] == 0 + if modified and outermost and success: ctrl._end_batch() cls._MODIFIED_CARTS.remove(_id) + # Clear out the cache on the outermost operation + if outermost: + del cls._FOR_USER[user] + def _fail_if_cart_is_not_active(self): self.cart.refresh_from_db() if self.cart.status != commerce.Cart.STATUS_ACTIVE: From 135f2fb47b13dcbd31b640449dabe0cdcee80df9 Mon Sep 17 00:00:00 2001 From: Christopher Neugebauer Date: Fri, 29 Apr 2016 10:57:33 +1000 Subject: [PATCH 14/15] Refactors discounts validation in terms of available_discounts --- registrasion/controllers/cart.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/registrasion/controllers/cart.py b/registrasion/controllers/cart.py index b445a802..5c5cf4f8 100644 --- a/registrasion/controllers/cart.py +++ b/registrasion/controllers/cart.py @@ -390,19 +390,22 @@ class CartController(object): # Validate the discounts # TODO: refactor in terms of available_discounts # why aren't we doing that here?! - discount_items = commerce.DiscountItem.objects.filter(cart=cart) - seen_discounts = set() + # def available_discounts(cls, user, categories, products): + + products = [i.product for i in items] + discounts_with_quantity = DiscountController.available_discounts( + user, + [], + products, + ) + discounts = set(i.discount.id for i in discounts_with_quantity) + + discount_items = commerce.DiscountItem.objects.filter(cart=cart) for discount_item in discount_items: discount = discount_item.discount - if discount in seen_discounts: - continue - seen_discounts.add(discount) - real_discount = conditions.DiscountBase.objects.get_subclass( - pk=discount.pk) - cond = ConditionController.for_condition(real_discount) - if not cond.is_met(user): + if discount.id not in discounts: errors.append( ValidationError("Discounts are no longer available") ) From b40505117f4c9e2abe4464ffa6505d23a1440ac2 Mon Sep 17 00:00:00 2001 From: Christopher Neugebauer Date: Fri, 29 Apr 2016 11:22:56 +1000 Subject: [PATCH 15/15] Fixes flake8 errors arising from rebase --- registrasion/controllers/cart.py | 5 ----- registrasion/controllers/category.py | 1 - registrasion/controllers/conditions.py | 12 +++--------- registrasion/controllers/discount.py | 1 - registrasion/controllers/product.py | 5 ++--- registrasion/templatetags/registrasion_tags.py | 2 +- registrasion/views.py | 3 --- 7 files changed, 6 insertions(+), 23 deletions(-) diff --git a/registrasion/controllers/cart.py b/registrasion/controllers/cart.py index 5c5cf4f8..dbf7e8a0 100644 --- a/registrasion/controllers/cart.py +++ b/registrasion/controllers/cart.py @@ -17,7 +17,6 @@ from registrasion.models import conditions from registrasion.models import inventory from .category import CategoryController -from .conditions import ConditionController from .discount import DiscountController from .flag import FlagController from .product import ProductController @@ -65,7 +64,6 @@ class CartController(object): ) return cls(existing) - # Marks the carts that are currently in batches _FOR_USER = {} _BATCH_COUNT = collections.defaultdict(int) @@ -156,7 +154,6 @@ class CartController(object): ''' - self.cart.refresh_from_db() self._recalculate_discounts() @@ -250,8 +247,6 @@ class CartController(object): # Test each category limit here for category in by_cat: - #ctrl = CategoryController(category) - #limit = ctrl.user_quantity_remaining(self.cart.user) limit = with_remainders[category].remainder # Get the amount so far in the cart diff --git a/registrasion/controllers/category.py b/registrasion/controllers/category.py index c3f38ed9..9db8ca9e 100644 --- a/registrasion/controllers/category.py +++ b/registrasion/controllers/category.py @@ -38,7 +38,6 @@ class CategoryController(object): return set(i.category for i in available) - @classmethod def attach_user_remainders(cls, user, categories): ''' diff --git a/registrasion/controllers/conditions.py b/registrasion/controllers/conditions.py index 2480a4fc..51078016 100644 --- a/registrasion/controllers/conditions.py +++ b/registrasion/controllers/conditions.py @@ -1,5 +1,3 @@ -import itertools - from django.db.models import Case from django.db.models import F, Q from django.db.models import Sum @@ -9,8 +7,6 @@ from django.utils import timezone from registrasion.models import commerce from registrasion.models import conditions -from registrasion.models import inventory - _BIG_QUANTITY = 99999999 # A big quantity @@ -134,8 +130,6 @@ class RemainderSetByFilter(object): if hasattr(self.condition, "remainder"): return self.condition.remainder - - # Mark self.condition with a remainder qs = type(self.condition).objects.filter(pk=self.condition.id) qs = self.pre_filter(qs, user) @@ -188,9 +182,9 @@ class ProductConditionController(IsMetByFilter, ConditionController): class TimeOrStockLimitConditionController( - RemainderSetByFilter, - ConditionController, - ): + RemainderSetByFilter, + ConditionController, + ): ''' Common condition tests for TimeOrStockLimit Flag and Discount.''' diff --git a/registrasion/controllers/discount.py b/registrasion/controllers/discount.py index 164d95cc..f4f88ed2 100644 --- a/registrasion/controllers/discount.py +++ b/registrasion/controllers/discount.py @@ -50,7 +50,6 @@ class DiscountController(object): categories and products. The discounts also list the available quantity for this user, not including products that are pending purchase. ''' - filtered_clauses = cls._filtered_discounts(user, categories, products) discounts = [] diff --git a/registrasion/controllers/product.py b/registrasion/controllers/product.py index 74783002..610c7f0d 100644 --- a/registrasion/controllers/product.py +++ b/registrasion/controllers/product.py @@ -36,10 +36,10 @@ class ProductController(object): categories = set(product.category for product in all_products) r = CategoryController.attach_user_remainders(user, categories) - cat_quants = dict((c,c) for c in r) + cat_quants = dict((c, c) for c in r) r = ProductController.attach_user_remainders(user, all_products) - prod_quants = dict((p,p) for p in r) + prod_quants = dict((p, p) for p in r) passed_limits = set( product @@ -58,7 +58,6 @@ class ProductController(object): return out - @classmethod def attach_user_remainders(cls, user, products): ''' diff --git a/registrasion/templatetags/registrasion_tags.py b/registrasion/templatetags/registrasion_tags.py index 6100d589..9074781c 100644 --- a/registrasion/templatetags/registrasion_tags.py +++ b/registrasion/templatetags/registrasion_tags.py @@ -103,7 +103,7 @@ def items_purchased(context, category=None): ''' - in_cart=( + in_cart = ( Q(productitem__cart__user=context.request.user) & Q(productitem__cart__status=commerce.Cart.STATUS_PAID) ) diff --git a/registrasion/views.py b/registrasion/views.py index a7917406..a4dcceac 100644 --- a/registrasion/views.py +++ b/registrasion/views.py @@ -451,9 +451,6 @@ def _set_quantities_from_products_form(products_form, current_cart): id__in=pks, ).select_related("category") - - - # TODO: This is fundamentally dumb product_quantities = [ (product, id_to_quantity[product.id]) for product in products ]