Merge branch 'query-optimisation'
This commit is contained in:
		
						commit
						6956c78b0d
					
				
					 14 changed files with 1064 additions and 449 deletions
				
			
		|  | @ -1,6 +1,6 @@ | ||||||
| import collections | import collections | ||||||
|  | import contextlib | ||||||
| import datetime | import datetime | ||||||
| import discount |  | ||||||
| import functools | import functools | ||||||
| import itertools | import itertools | ||||||
| 
 | 
 | ||||||
|  | @ -8,6 +8,7 @@ from django.core.exceptions import ObjectDoesNotExist | ||||||
| from django.core.exceptions import ValidationError | from django.core.exceptions import ValidationError | ||||||
| from django.db import transaction | from django.db import transaction | ||||||
| from django.db.models import Max | from django.db.models import Max | ||||||
|  | from django.db.models import Q | ||||||
| from django.utils import timezone | from django.utils import timezone | ||||||
| 
 | 
 | ||||||
| from registrasion.exceptions import CartValidationError | from registrasion.exceptions import CartValidationError | ||||||
|  | @ -15,19 +16,27 @@ from registrasion.models import commerce | ||||||
| from registrasion.models import conditions | from registrasion.models import conditions | ||||||
| from registrasion.models import inventory | from registrasion.models import inventory | ||||||
| 
 | 
 | ||||||
| from category import CategoryController | from .category import CategoryController | ||||||
| from conditions import ConditionController | from .discount import DiscountController | ||||||
| from product import ProductController | from .flag import FlagController | ||||||
|  | from .product import ProductController | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def _modifies_cart(func): | def _modifies_cart(func): | ||||||
|     ''' Decorator that makes the wrapped function raise ValidationError |     ''' Decorator that makes the wrapped function raise ValidationError | ||||||
|     if we're doing something that could modify the cart. ''' |     if we're doing something that could modify the cart. | ||||||
|  | 
 | ||||||
|  |     It also wraps the execution of this function in a database transaction, | ||||||
|  |     and marks the boundaries of a cart operations batch. | ||||||
|  |     ''' | ||||||
| 
 | 
 | ||||||
|     @functools.wraps(func) |     @functools.wraps(func) | ||||||
|     def inner(self, *a, **k): |     def inner(self, *a, **k): | ||||||
|         self._fail_if_cart_is_not_active() |         self._fail_if_cart_is_not_active() | ||||||
|         return func(self, *a, **k) |         with transaction.atomic(): | ||||||
|  |             with CartController.operations_batch(self.cart.user) as mark: | ||||||
|  |                 mark.mark = True  # Marker that we've modified the cart | ||||||
|  |                 return func(self, *a, **k) | ||||||
| 
 | 
 | ||||||
|     return inner |     return inner | ||||||
| 
 | 
 | ||||||
|  | @ -55,13 +64,65 @@ class CartController(object): | ||||||
|             ) |             ) | ||||||
|         return cls(existing) |         return cls(existing) | ||||||
| 
 | 
 | ||||||
|  |     # Marks the carts that are currently in batches | ||||||
|  |     _FOR_USER = {} | ||||||
|  |     _BATCH_COUNT = collections.defaultdict(int) | ||||||
|  |     _MODIFIED_CARTS = set() | ||||||
|  | 
 | ||||||
|  |     class _ModificationMarker(object): | ||||||
|  |         pass | ||||||
|  | 
 | ||||||
|  |     @classmethod | ||||||
|  |     @contextlib.contextmanager | ||||||
|  |     def operations_batch(cls, user): | ||||||
|  |         ''' Marks the boundary for a batch of operations on a user's cart. | ||||||
|  | 
 | ||||||
|  |         These markers can be nested. Only on exiting the outermost marker will | ||||||
|  |         a batch be ended. | ||||||
|  | 
 | ||||||
|  |         When a batch is ended, discounts are recalculated, and the cart's | ||||||
|  |         revision is increased. | ||||||
|  |         ''' | ||||||
|  | 
 | ||||||
|  |         if user not in cls._FOR_USER: | ||||||
|  |             _ctrl = cls.for_user(user) | ||||||
|  |             cls._FOR_USER[user] = (_ctrl, _ctrl.cart.id) | ||||||
|  | 
 | ||||||
|  |         ctrl, _id = cls._FOR_USER[user] | ||||||
|  | 
 | ||||||
|  |         cls._BATCH_COUNT[_id] += 1 | ||||||
|  |         try: | ||||||
|  |             success = False | ||||||
|  | 
 | ||||||
|  |             marker = cls._ModificationMarker() | ||||||
|  |             yield marker | ||||||
|  | 
 | ||||||
|  |             if hasattr(marker, "mark"): | ||||||
|  |                 cls._MODIFIED_CARTS.add(_id) | ||||||
|  | 
 | ||||||
|  |             success = True | ||||||
|  |         finally: | ||||||
|  | 
 | ||||||
|  |             cls._BATCH_COUNT[_id] -= 1 | ||||||
|  | 
 | ||||||
|  |             # Only end on the outermost batch marker, and only if | ||||||
|  |             # it excited cleanly, and a modification occurred | ||||||
|  |             modified = _id in cls._MODIFIED_CARTS | ||||||
|  |             outermost = cls._BATCH_COUNT[_id] == 0 | ||||||
|  |             if modified and outermost and success: | ||||||
|  |                 ctrl._end_batch() | ||||||
|  |                 cls._MODIFIED_CARTS.remove(_id) | ||||||
|  | 
 | ||||||
|  |             # Clear out the cache on the outermost operation | ||||||
|  |             if outermost: | ||||||
|  |                 del cls._FOR_USER[user] | ||||||
|  | 
 | ||||||
|     def _fail_if_cart_is_not_active(self): |     def _fail_if_cart_is_not_active(self): | ||||||
|         self.cart.refresh_from_db() |         self.cart.refresh_from_db() | ||||||
|         if self.cart.status != commerce.Cart.STATUS_ACTIVE: |         if self.cart.status != commerce.Cart.STATUS_ACTIVE: | ||||||
|             raise ValidationError("You can only amend active carts.") |             raise ValidationError("You can only amend active carts.") | ||||||
| 
 | 
 | ||||||
|     @_modifies_cart |     def _autoextend_reservation(self): | ||||||
|     def extend_reservation(self): |  | ||||||
|         ''' Updates the cart's time last updated value, which is used to |         ''' Updates the cart's time last updated value, which is used to | ||||||
|         determine whether the cart has reserved the items and discounts it |         determine whether the cart has reserved the items and discounts it | ||||||
|         holds. ''' |         holds. ''' | ||||||
|  | @ -83,21 +144,25 @@ class CartController(object): | ||||||
|         self.cart.time_last_updated = timezone.now() |         self.cart.time_last_updated = timezone.now() | ||||||
|         self.cart.reservation_duration = max(reservations) |         self.cart.reservation_duration = max(reservations) | ||||||
| 
 | 
 | ||||||
|     @_modifies_cart |     def _end_batch(self): | ||||||
|     def end_batch(self): |  | ||||||
|         ''' Performs operations that occur occur at the end of a batch of |         ''' Performs operations that occur occur at the end of a batch of | ||||||
|         product changes/voucher applications etc. |         product changes/voucher applications etc. | ||||||
|         THIS SHOULD BE PRIVATE | 
 | ||||||
|  |         You need to call this after you've finished modifying the user's cart. | ||||||
|  |         This is normally done by wrapping a block of code using | ||||||
|  |         ``operations_batch``. | ||||||
|  | 
 | ||||||
|         ''' |         ''' | ||||||
| 
 | 
 | ||||||
|         self.recalculate_discounts() |         self.cart.refresh_from_db() | ||||||
| 
 | 
 | ||||||
|         self.extend_reservation() |         self._recalculate_discounts() | ||||||
|  | 
 | ||||||
|  |         self._autoextend_reservation() | ||||||
|         self.cart.revision += 1 |         self.cart.revision += 1 | ||||||
|         self.cart.save() |         self.cart.save() | ||||||
| 
 | 
 | ||||||
|     @_modifies_cart |     @_modifies_cart | ||||||
|     @transaction.atomic |  | ||||||
|     def set_quantities(self, product_quantities): |     def set_quantities(self, product_quantities): | ||||||
|         ''' Sets the quantities on each of the products on each of the |         ''' Sets the quantities on each of the products on each of the | ||||||
|         products specified. Raises an exception (ValidationError) if a limit |         products specified. Raises an exception (ValidationError) if a limit | ||||||
|  | @ -122,24 +187,28 @@ class CartController(object): | ||||||
|         # Validate that the limits we're adding are OK |         # Validate that the limits we're adding are OK | ||||||
|         self._test_limits(all_product_quantities) |         self._test_limits(all_product_quantities) | ||||||
| 
 | 
 | ||||||
|  |         new_items = [] | ||||||
|  |         products = [] | ||||||
|         for product, quantity in product_quantities: |         for product, quantity in product_quantities: | ||||||
|             try: |             products.append(product) | ||||||
|                 product_item = commerce.ProductItem.objects.get( |  | ||||||
|                     cart=self.cart, |  | ||||||
|                     product=product, |  | ||||||
|                 ) |  | ||||||
|                 product_item.quantity = quantity |  | ||||||
|                 product_item.save() |  | ||||||
|             except ObjectDoesNotExist: |  | ||||||
|                 commerce.ProductItem.objects.create( |  | ||||||
|                     cart=self.cart, |  | ||||||
|                     product=product, |  | ||||||
|                     quantity=quantity, |  | ||||||
|                 ) |  | ||||||
| 
 | 
 | ||||||
|         items_in_cart.filter(quantity=0).delete() |             if quantity == 0: | ||||||
|  |                 continue | ||||||
| 
 | 
 | ||||||
|         self.end_batch() |             item = commerce.ProductItem( | ||||||
|  |                 cart=self.cart, | ||||||
|  |                 product=product, | ||||||
|  |                 quantity=quantity, | ||||||
|  |             ) | ||||||
|  |             new_items.append(item) | ||||||
|  | 
 | ||||||
|  |         to_delete = ( | ||||||
|  |             Q(quantity=0) | | ||||||
|  |             Q(product__in=products) | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         items_in_cart.filter(to_delete).delete() | ||||||
|  |         commerce.ProductItem.objects.bulk_create(new_items) | ||||||
| 
 | 
 | ||||||
|     def _test_limits(self, product_quantities): |     def _test_limits(self, product_quantities): | ||||||
|         ''' Tests that the quantity changes we intend to make do not violate |         ''' Tests that the quantity changes we intend to make do not violate | ||||||
|  | @ -147,13 +216,17 @@ class CartController(object): | ||||||
| 
 | 
 | ||||||
|         errors = [] |         errors = [] | ||||||
| 
 | 
 | ||||||
|  |         # Pre-annotate products | ||||||
|  |         products = [p for (p, q) in product_quantities] | ||||||
|  |         r = ProductController.attach_user_remainders(self.cart.user, products) | ||||||
|  |         with_remainders = dict((p, p) for p in r) | ||||||
|  | 
 | ||||||
|         # Test each product limit here |         # Test each product limit here | ||||||
|         for product, quantity in product_quantities: |         for product, quantity in product_quantities: | ||||||
|             if quantity < 0: |             if quantity < 0: | ||||||
|                 errors.append((product, "Value must be zero or greater.")) |                 errors.append((product, "Value must be zero or greater.")) | ||||||
| 
 | 
 | ||||||
|             prod = ProductController(product) |             limit = with_remainders[product].remainder | ||||||
|             limit = prod.user_quantity_remaining(self.cart.user) |  | ||||||
| 
 | 
 | ||||||
|             if quantity > limit: |             if quantity > limit: | ||||||
|                 errors.append(( |                 errors.append(( | ||||||
|  | @ -168,10 +241,13 @@ class CartController(object): | ||||||
|         for product, quantity in product_quantities: |         for product, quantity in product_quantities: | ||||||
|             by_cat[product.category].append((product, quantity)) |             by_cat[product.category].append((product, quantity)) | ||||||
| 
 | 
 | ||||||
|  |         # Pre-annotate categories | ||||||
|  |         r = CategoryController.attach_user_remainders(self.cart.user, by_cat) | ||||||
|  |         with_remainders = dict((cat, cat) for cat in r) | ||||||
|  | 
 | ||||||
|         # Test each category limit here |         # Test each category limit here | ||||||
|         for category in by_cat: |         for category in by_cat: | ||||||
|             ctrl = CategoryController(category) |             limit = with_remainders[category].remainder | ||||||
|             limit = ctrl.user_quantity_remaining(self.cart.user) |  | ||||||
| 
 | 
 | ||||||
|             # Get the amount so far in the cart |             # Get the amount so far in the cart | ||||||
|             to_add = sum(i[1] for i in by_cat[category]) |             to_add = sum(i[1] for i in by_cat[category]) | ||||||
|  | @ -185,7 +261,7 @@ class CartController(object): | ||||||
|                 )) |                 )) | ||||||
| 
 | 
 | ||||||
|         # Test the flag conditions |         # Test the flag conditions | ||||||
|         errs = ConditionController.test_flags( |         errs = FlagController.test_flags( | ||||||
|             self.cart.user, |             self.cart.user, | ||||||
|             product_quantities=product_quantities, |             product_quantities=product_quantities, | ||||||
|         ) |         ) | ||||||
|  | @ -212,7 +288,6 @@ class CartController(object): | ||||||
| 
 | 
 | ||||||
|         # If successful... |         # If successful... | ||||||
|         self.cart.vouchers.add(voucher) |         self.cart.vouchers.add(voucher) | ||||||
|         self.end_batch() |  | ||||||
| 
 | 
 | ||||||
|     def _test_voucher(self, voucher): |     def _test_voucher(self, voucher): | ||||||
|         ''' Tests whether this voucher is allowed to be applied to this cart. |         ''' Tests whether this voucher is allowed to be applied to this cart. | ||||||
|  | @ -294,6 +369,7 @@ class CartController(object): | ||||||
|             errors.append(ve) |             errors.append(ve) | ||||||
| 
 | 
 | ||||||
|         items = commerce.ProductItem.objects.filter(cart=cart) |         items = commerce.ProductItem.objects.filter(cart=cart) | ||||||
|  |         items = items.select_related("product", "product__category") | ||||||
| 
 | 
 | ||||||
|         product_quantities = list((i.product, i.quantity) for i in items) |         product_quantities = list((i.product, i.quantity) for i in items) | ||||||
|         try: |         try: | ||||||
|  | @ -307,19 +383,24 @@ class CartController(object): | ||||||
|             self._append_errors(errors, ve) |             self._append_errors(errors, ve) | ||||||
| 
 | 
 | ||||||
|         # Validate the discounts |         # Validate the discounts | ||||||
|         discount_items = commerce.DiscountItem.objects.filter(cart=cart) |         # TODO: refactor in terms of available_discounts | ||||||
|         seen_discounts = set() |         # why aren't we doing that here?! | ||||||
| 
 | 
 | ||||||
|  |         #     def available_discounts(cls, user, categories, products): | ||||||
|  | 
 | ||||||
|  |         products = [i.product for i in items] | ||||||
|  |         discounts_with_quantity = DiscountController.available_discounts( | ||||||
|  |             user, | ||||||
|  |             [], | ||||||
|  |             products, | ||||||
|  |         ) | ||||||
|  |         discounts = set(i.discount.id for i in discounts_with_quantity) | ||||||
|  | 
 | ||||||
|  |         discount_items = commerce.DiscountItem.objects.filter(cart=cart) | ||||||
|         for discount_item in discount_items: |         for discount_item in discount_items: | ||||||
|             discount = discount_item.discount |             discount = discount_item.discount | ||||||
|             if discount in seen_discounts: |  | ||||||
|                 continue |  | ||||||
|             seen_discounts.add(discount) |  | ||||||
|             real_discount = conditions.DiscountBase.objects.get_subclass( |  | ||||||
|                 pk=discount.pk) |  | ||||||
|             cond = ConditionController.for_condition(real_discount) |  | ||||||
| 
 | 
 | ||||||
|             if not cond.is_met(user): |             if discount.id not in discounts: | ||||||
|                 errors.append( |                 errors.append( | ||||||
|                     ValidationError("Discounts are no longer available") |                     ValidationError("Discounts are no longer available") | ||||||
|                 ) |                 ) | ||||||
|  | @ -328,7 +409,6 @@ class CartController(object): | ||||||
|             raise ValidationError(errors) |             raise ValidationError(errors) | ||||||
| 
 | 
 | ||||||
|     @_modifies_cart |     @_modifies_cart | ||||||
|     @transaction.atomic |  | ||||||
|     def fix_simple_errors(self): |     def fix_simple_errors(self): | ||||||
|         ''' This attempts to fix the easy errors raised by ValidationError. |         ''' This attempts to fix the easy errors raised by ValidationError. | ||||||
|         This includes removing items from the cart that are no longer |         This includes removing items from the cart that are no longer | ||||||
|  | @ -360,11 +440,9 @@ class CartController(object): | ||||||
| 
 | 
 | ||||||
|         self.set_quantities(zeros) |         self.set_quantities(zeros) | ||||||
| 
 | 
 | ||||||
|     @_modifies_cart |  | ||||||
|     @transaction.atomic |     @transaction.atomic | ||||||
|     def recalculate_discounts(self): |     def _recalculate_discounts(self): | ||||||
|         ''' Calculates all of the discounts available for this product. |         ''' Calculates all of the discounts available for this product.''' | ||||||
|         ''' |  | ||||||
| 
 | 
 | ||||||
|         # Delete the existing entries. |         # Delete the existing entries. | ||||||
|         commerce.DiscountItem.objects.filter(cart=self.cart).delete() |         commerce.DiscountItem.objects.filter(cart=self.cart).delete() | ||||||
|  | @ -374,7 +452,11 @@ class CartController(object): | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|         products = [i.product for i in product_items] |         products = [i.product for i in product_items] | ||||||
|         discounts = discount.available_discounts(self.cart.user, [], products) |         discounts = DiscountController.available_discounts( | ||||||
|  |             self.cart.user, | ||||||
|  |             [], | ||||||
|  |             products, | ||||||
|  |         ) | ||||||
| 
 | 
 | ||||||
|         # The highest-value discounts will apply to the highest-value |         # The highest-value discounts will apply to the highest-value | ||||||
|         # products first. |         # products first. | ||||||
|  |  | ||||||
|  | @ -1,7 +1,11 @@ | ||||||
| from registrasion.models import commerce | from registrasion.models import commerce | ||||||
| from registrasion.models import inventory | from registrasion.models import inventory | ||||||
| 
 | 
 | ||||||
|  | from django.db.models import Case | ||||||
|  | from django.db.models import F, Q | ||||||
| from django.db.models import Sum | from django.db.models import Sum | ||||||
|  | from django.db.models import When | ||||||
|  | from django.db.models import Value | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class AllProducts(object): | class AllProducts(object): | ||||||
|  | @ -34,25 +38,47 @@ class CategoryController(object): | ||||||
| 
 | 
 | ||||||
|         return set(i.category for i in available) |         return set(i.category for i in available) | ||||||
| 
 | 
 | ||||||
|  |     @classmethod | ||||||
|  |     def attach_user_remainders(cls, user, categories): | ||||||
|  |         ''' | ||||||
|  | 
 | ||||||
|  |         Return: | ||||||
|  |             queryset(inventory.Product): A queryset containing items from | ||||||
|  |             ``categories``, with an extra attribute -- remainder = the amount | ||||||
|  |             of items from this category that is remaining. | ||||||
|  |         ''' | ||||||
|  | 
 | ||||||
|  |         ids = [category.id for category in categories] | ||||||
|  |         categories = inventory.Category.objects.filter(id__in=ids) | ||||||
|  | 
 | ||||||
|  |         cart_filter = ( | ||||||
|  |             Q(product__productitem__cart__user=user) & | ||||||
|  |             Q(product__productitem__cart__status=commerce.Cart.STATUS_PAID) | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         quantity = When( | ||||||
|  |             cart_filter, | ||||||
|  |             then='product__productitem__quantity' | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         quantity_or_zero = Case( | ||||||
|  |             quantity, | ||||||
|  |             default=Value(0), | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         remainder = Case( | ||||||
|  |             When(limit_per_user=None, then=Value(99999999)), | ||||||
|  |             default=F('limit_per_user') - Sum(quantity_or_zero), | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         categories = categories.annotate(remainder=remainder) | ||||||
|  | 
 | ||||||
|  |         return categories | ||||||
|  | 
 | ||||||
|     def user_quantity_remaining(self, user): |     def user_quantity_remaining(self, user): | ||||||
|         ''' Returns the number of items from this category that the user may |         ''' Returns the quantity of this product that the user add in the | ||||||
|         add in the current cart. ''' |         current cart. ''' | ||||||
| 
 | 
 | ||||||
|         cat_limit = self.category.limit_per_user |         with_remainders = self.attach_user_remainders(user, [self.category]) | ||||||
| 
 | 
 | ||||||
|         if cat_limit is None: |         return with_remainders[0].remainder | ||||||
|             # We don't need to waste the following queries |  | ||||||
|             return 99999999 |  | ||||||
| 
 |  | ||||||
|         carts = commerce.Cart.objects.filter( |  | ||||||
|             user=user, |  | ||||||
|             status=commerce.Cart.STATUS_PAID, |  | ||||||
|         ) |  | ||||||
| 
 |  | ||||||
|         items = commerce.ProductItem.objects.filter( |  | ||||||
|             cart__in=carts, |  | ||||||
|             product__category=self.category, |  | ||||||
|         ) |  | ||||||
| 
 |  | ||||||
|         cat_count = items.aggregate(Sum("quantity"))["quantity__sum"] or 0 |  | ||||||
|         return cat_limit - cat_count |  | ||||||
|  |  | ||||||
|  | @ -1,36 +1,27 @@ | ||||||
| import itertools | from django.db.models import Case | ||||||
| import operator | from django.db.models import F, Q | ||||||
| 
 |  | ||||||
| from collections import defaultdict |  | ||||||
| from collections import namedtuple |  | ||||||
| 
 |  | ||||||
| from django.db.models import Sum | from django.db.models import Sum | ||||||
|  | from django.db.models import Value | ||||||
|  | from django.db.models import When | ||||||
| from django.utils import timezone | from django.utils import timezone | ||||||
| 
 | 
 | ||||||
| from registrasion.models import commerce | from registrasion.models import commerce | ||||||
| from registrasion.models import conditions | from registrasion.models import conditions | ||||||
| from registrasion.models import inventory |  | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| ConditionAndRemainder = namedtuple( | _BIG_QUANTITY = 99999999  # A big quantity | ||||||
|     "ConditionAndRemainder", |  | ||||||
|     ( |  | ||||||
|         "condition", |  | ||||||
|         "remainder", |  | ||||||
|     ), |  | ||||||
| ) |  | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class ConditionController(object): | class ConditionController(object): | ||||||
|     ''' Base class for testing conditions that activate Flag |     ''' Base class for testing conditions that activate Flag | ||||||
|     or Discount objects. ''' |     or Discount objects. ''' | ||||||
| 
 | 
 | ||||||
|     def __init__(self): |     def __init__(self, condition): | ||||||
|         pass |         self.condition = condition | ||||||
| 
 | 
 | ||||||
|     @staticmethod |     @staticmethod | ||||||
|     def for_condition(condition): |     def _controllers(): | ||||||
|         CONTROLLERS = { |         return { | ||||||
|             conditions.CategoryFlag: CategoryConditionController, |             conditions.CategoryFlag: CategoryConditionController, | ||||||
|             conditions.IncludedProductDiscount: ProductConditionController, |             conditions.IncludedProductDiscount: ProductConditionController, | ||||||
|             conditions.ProductFlag: ProductConditionController, |             conditions.ProductFlag: ProductConditionController, | ||||||
|  | @ -42,137 +33,49 @@ class ConditionController(object): | ||||||
|             conditions.VoucherFlag: VoucherConditionController, |             conditions.VoucherFlag: VoucherConditionController, | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|  |     @staticmethod | ||||||
|  |     def for_type(cls): | ||||||
|  |         return ConditionController._controllers()[cls] | ||||||
|  | 
 | ||||||
|  |     @staticmethod | ||||||
|  |     def for_condition(condition): | ||||||
|         try: |         try: | ||||||
|             return CONTROLLERS[type(condition)](condition) |             return ConditionController.for_type(type(condition))(condition) | ||||||
|         except KeyError: |         except KeyError: | ||||||
|             return ConditionController() |             return ConditionController() | ||||||
| 
 | 
 | ||||||
|     SINGLE = True |  | ||||||
|     PLURAL = False |  | ||||||
|     NONE = True |  | ||||||
|     SOME = False |  | ||||||
|     MESSAGE = { |  | ||||||
|         NONE: { |  | ||||||
|             SINGLE: |  | ||||||
|                 "%(items)s is no longer available to you", |  | ||||||
|             PLURAL: |  | ||||||
|                 "%(items)s are no longer available to you", |  | ||||||
|         }, |  | ||||||
|         SOME: { |  | ||||||
|             SINGLE: |  | ||||||
|                 "Only %(remainder)d of the following item remains: %(items)s", |  | ||||||
|             PLURAL: |  | ||||||
|                 "Only %(remainder)d of the following items remain: %(items)s" |  | ||||||
|         }, |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @classmethod |     @classmethod | ||||||
|     def test_flags( |     def pre_filter(cls, queryset, user): | ||||||
|             cls, user, products=None, product_quantities=None): |         ''' Returns only the flag conditions that might be available for this | ||||||
|         ''' Evaluates all of the flag conditions on the given products. |         user. It should hopefully reduce the number of queries that need to be | ||||||
|  |         executed to determine if a flag is met. | ||||||
| 
 | 
 | ||||||
|         If `product_quantities` is supplied, the condition is only met if it |         If this filtration implements the same query as is_met, then you should | ||||||
|         will permit the sum of the product quantities for all of the products |         be able to implement ``is_met()`` in terms of this. | ||||||
|         it covers. Otherwise, it will be met if at least one item can be |  | ||||||
|         accepted. |  | ||||||
| 
 | 
 | ||||||
|         If all flag conditions pass, an empty list is returned, otherwise |         Arguments: | ||||||
|         a list is returned containing all of the products that are *not |  | ||||||
|         enabled*. ''' |  | ||||||
| 
 | 
 | ||||||
|         if products is not None and product_quantities is not None: |             queryset (Queryset[c]): The canditate conditions. | ||||||
|             raise ValueError("Please specify only products or " |  | ||||||
|                              "product_quantities") |  | ||||||
|         elif products is None: |  | ||||||
|             products = set(i[0] for i in product_quantities) |  | ||||||
|             quantities = dict((product, quantity) |  | ||||||
|                               for product, quantity in product_quantities) |  | ||||||
|         elif product_quantities is None: |  | ||||||
|             products = set(products) |  | ||||||
|             quantities = {} |  | ||||||
| 
 | 
 | ||||||
|         # Get the conditions covered by the products themselves |             user (User): The user for whom we're testing these conditions. | ||||||
|         prods = ( |  | ||||||
|             product.flagbase_set.select_subclasses() |  | ||||||
|             for product in products |  | ||||||
|         ) |  | ||||||
|         # Get the conditions covered by their categories |  | ||||||
|         cats = ( |  | ||||||
|             category.flagbase_set.select_subclasses() |  | ||||||
|             for category in set(product.category for product in products) |  | ||||||
|         ) |  | ||||||
| 
 | 
 | ||||||
|         if products: |         Returns: | ||||||
|             # Simplify the query. |             Queryset[c]: A subset of the conditions that pass the pre-filter | ||||||
|             all_conditions = reduce(operator.or_, itertools.chain(prods, cats)) |                 test for this user. | ||||||
|         else: |  | ||||||
|             all_conditions = [] |  | ||||||
| 
 | 
 | ||||||
|         # All disable-if-false conditions on a product need to be met |         ''' | ||||||
|         do_not_disable = defaultdict(lambda: True) |  | ||||||
|         # At least one enable-if-true condition on a product must be met |  | ||||||
|         do_enable = defaultdict(lambda: False) |  | ||||||
|         # (if either sort of condition is present) |  | ||||||
| 
 | 
 | ||||||
|         messages = {} |         # Default implementation does NOTHING. | ||||||
|  |         return queryset | ||||||
| 
 | 
 | ||||||
|         for condition in all_conditions: |     def passes_filter(self, user): | ||||||
|             cond = cls.for_condition(condition) |         ''' Returns true if the condition passes the filter ''' | ||||||
|             remainder = cond.user_quantity_remaining(user) |  | ||||||
| 
 | 
 | ||||||
|             # Get all products covered by this condition, and the products |         cls = type(self.condition) | ||||||
|             # from the categories covered by this condition |         qs = cls.objects.filter(pk=self.condition.id) | ||||||
|             cond_products = condition.products.all() |         return self.condition in self.pre_filter(qs, user) | ||||||
|             from_category = inventory.Product.objects.filter( |  | ||||||
|                 category__in=condition.categories.all(), |  | ||||||
|             ).all() |  | ||||||
|             all_products = cond_products | from_category |  | ||||||
|             all_products = all_products.select_related("category") |  | ||||||
|             # Remove the products that we aren't asking about |  | ||||||
|             all_products = [ |  | ||||||
|                 product |  | ||||||
|                 for product in all_products |  | ||||||
|                 if product in products |  | ||||||
|             ] |  | ||||||
| 
 | 
 | ||||||
|             if quantities: |     def user_quantity_remaining(self, user, filtered=False): | ||||||
|                 consumed = sum(quantities[i] for i in all_products) |  | ||||||
|             else: |  | ||||||
|                 consumed = 1 |  | ||||||
|             met = consumed <= remainder |  | ||||||
| 
 |  | ||||||
|             if not met: |  | ||||||
|                 items = ", ".join(str(product) for product in all_products) |  | ||||||
|                 base = cls.MESSAGE[remainder == 0][len(all_products) == 1] |  | ||||||
|                 message = base % {"items": items, "remainder": remainder} |  | ||||||
| 
 |  | ||||||
|             for product in all_products: |  | ||||||
|                 if condition.is_disable_if_false: |  | ||||||
|                     do_not_disable[product] &= met |  | ||||||
|                 else: |  | ||||||
|                     do_enable[product] |= met |  | ||||||
| 
 |  | ||||||
|                 if not met and product not in messages: |  | ||||||
|                     messages[product] = message |  | ||||||
| 
 |  | ||||||
|         valid = {} |  | ||||||
|         for product in itertools.chain(do_not_disable, do_enable): |  | ||||||
|             if product in do_enable: |  | ||||||
|                 # If there's an enable-if-true, we need need of those met too. |  | ||||||
|                 # (do_not_disable will default to true otherwise) |  | ||||||
|                 valid[product] = do_not_disable[product] and do_enable[product] |  | ||||||
|             elif product in do_not_disable: |  | ||||||
|                 # If there's a disable-if-false condition, all must be met |  | ||||||
|                 valid[product] = do_not_disable[product] |  | ||||||
| 
 |  | ||||||
|         error_fields = [ |  | ||||||
|             (product, messages[product]) |  | ||||||
|             for product in valid if not valid[product] |  | ||||||
|         ] |  | ||||||
| 
 |  | ||||||
|         return error_fields |  | ||||||
| 
 |  | ||||||
|     def user_quantity_remaining(self, user): |  | ||||||
|         ''' Returns the number of items covered by this flag condition the |         ''' Returns the number of items covered by this flag condition the | ||||||
|         user can add to the current cart. This default implementation returns |         user can add to the current cart. This default implementation returns | ||||||
|         a big number if is_met() is true, otherwise 0. |         a big number if is_met() is true, otherwise 0. | ||||||
|  | @ -180,144 +83,210 @@ class ConditionController(object): | ||||||
|         Either this method, or is_met() must be overridden in subclasses. |         Either this method, or is_met() must be overridden in subclasses. | ||||||
|         ''' |         ''' | ||||||
| 
 | 
 | ||||||
|         return 99999999 if self.is_met(user) else 0 |         return _BIG_QUANTITY if self.is_met(user, filtered) else 0 | ||||||
| 
 | 
 | ||||||
|     def is_met(self, user): |     def is_met(self, user, filtered=False): | ||||||
|         ''' Returns True if this flag condition is met, otherwise returns |         ''' Returns True if this flag condition is met, otherwise returns | ||||||
|         False. |         False. | ||||||
| 
 | 
 | ||||||
|         Either this method, or user_quantity_remaining() must be overridden |         Either this method, or user_quantity_remaining() must be overridden | ||||||
|         in subclasses. |         in subclasses. | ||||||
|  | 
 | ||||||
|  |         Arguments: | ||||||
|  | 
 | ||||||
|  |             user (User): The user for whom this test must be met. | ||||||
|  | 
 | ||||||
|  |             filter (bool): If true, this condition was part of a queryset | ||||||
|  |                 returned by pre_filter() for this user. | ||||||
|  | 
 | ||||||
|         ''' |         ''' | ||||||
|         return self.user_quantity_remaining(user) > 0 |         return self.user_quantity_remaining(user, filtered) > 0 | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class CategoryConditionController(ConditionController): | class IsMetByFilter(object): | ||||||
| 
 | 
 | ||||||
|     def __init__(self, condition): |     def is_met(self, user, filtered=False): | ||||||
|         self.condition = condition |         ''' Returns True if this flag condition is met, otherwise returns | ||||||
|  |         False. It determines if the condition is met by calling pre_filter | ||||||
|  |         with a queryset containing only self.condition. ''' | ||||||
| 
 | 
 | ||||||
|     def is_met(self, user): |         if filtered: | ||||||
|         ''' returns True if the user has a product from a category that invokes |             return True  # Why query again? | ||||||
|         this condition in one of their carts ''' |  | ||||||
| 
 | 
 | ||||||
|         carts = commerce.Cart.objects.filter(user=user) |         return self.passes_filter(user) | ||||||
|         carts = carts.exclude(status=commerce.Cart.STATUS_RELEASED) | 
 | ||||||
|         enabling_products = inventory.Product.objects.filter( | 
 | ||||||
|             category=self.condition.enabling_category, | class RemainderSetByFilter(object): | ||||||
|  | 
 | ||||||
|  |     def user_quantity_remaining(self, user, filtered=True): | ||||||
|  |         ''' returns 0 if the date range is violated, otherwise, it will return | ||||||
|  |         the quantity remaining under the stock limit. | ||||||
|  | 
 | ||||||
|  |         The filter for this condition must add an annotation called "remainder" | ||||||
|  |         in order for this to work. | ||||||
|  |         ''' | ||||||
|  | 
 | ||||||
|  |         if filtered: | ||||||
|  |             if hasattr(self.condition, "remainder"): | ||||||
|  |                 return self.condition.remainder | ||||||
|  | 
 | ||||||
|  |         # Mark self.condition with a remainder | ||||||
|  |         qs = type(self.condition).objects.filter(pk=self.condition.id) | ||||||
|  |         qs = self.pre_filter(qs, user) | ||||||
|  | 
 | ||||||
|  |         if len(qs) > 0: | ||||||
|  |             return qs[0].remainder | ||||||
|  |         else: | ||||||
|  |             return 0 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class CategoryConditionController(IsMetByFilter, ConditionController): | ||||||
|  | 
 | ||||||
|  |     @classmethod | ||||||
|  |     def pre_filter(self, queryset, user): | ||||||
|  |         ''' Returns all of the items from queryset where the user has a | ||||||
|  |         product from a category invoking that item's condition in one of their | ||||||
|  |         carts. ''' | ||||||
|  | 
 | ||||||
|  |         in_user_carts = Q( | ||||||
|  |             enabling_category__product__productitem__cart__user=user | ||||||
|         ) |         ) | ||||||
|         products_count = commerce.ProductItem.objects.filter( |         released = commerce.Cart.STATUS_RELEASED | ||||||
|             cart__in=carts, |         in_released_carts = Q( | ||||||
|             product__in=enabling_products, |             enabling_category__product__productitem__cart__status=released | ||||||
|         ).count() |         ) | ||||||
|         return products_count > 0 |         queryset = queryset.filter(in_user_carts) | ||||||
|  |         queryset = queryset.exclude(in_released_carts) | ||||||
|  | 
 | ||||||
|  |         return queryset | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class ProductConditionController(ConditionController): | class ProductConditionController(IsMetByFilter, ConditionController): | ||||||
|     ''' Condition tests for ProductFlag and |     ''' Condition tests for ProductFlag and | ||||||
|     IncludedProductDiscount. ''' |     IncludedProductDiscount. ''' | ||||||
| 
 | 
 | ||||||
|     def __init__(self, condition): |     @classmethod | ||||||
|         self.condition = condition |     def pre_filter(self, queryset, user): | ||||||
|  |         ''' Returns all of the items from queryset where the user has a | ||||||
|  |         product invoking that item's condition in one of their carts. ''' | ||||||
| 
 | 
 | ||||||
|     def is_met(self, user): |         in_user_carts = Q(enabling_products__productitem__cart__user=user) | ||||||
|         ''' returns True if the user has a product that invokes this |         released = commerce.Cart.STATUS_RELEASED | ||||||
|         condition in one of their carts ''' |         in_released_carts = Q( | ||||||
|  |             enabling_products__productitem__cart__status=released | ||||||
|  |         ) | ||||||
|  |         queryset = queryset.filter(in_user_carts) | ||||||
|  |         queryset = queryset.exclude(in_released_carts) | ||||||
| 
 | 
 | ||||||
|         carts = commerce.Cart.objects.filter(user=user) |         return queryset | ||||||
|         carts = carts.exclude(status=commerce.Cart.STATUS_RELEASED) |  | ||||||
|         products_count = commerce.ProductItem.objects.filter( |  | ||||||
|             cart__in=carts, |  | ||||||
|             product__in=self.condition.enabling_products.all(), |  | ||||||
|         ).count() |  | ||||||
|         return products_count > 0 |  | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class TimeOrStockLimitConditionController(ConditionController): | class TimeOrStockLimitConditionController( | ||||||
|  |             RemainderSetByFilter, | ||||||
|  |             ConditionController, | ||||||
|  |         ): | ||||||
|     ''' Common condition tests for TimeOrStockLimit Flag and |     ''' Common condition tests for TimeOrStockLimit Flag and | ||||||
|     Discount.''' |     Discount.''' | ||||||
| 
 | 
 | ||||||
|     def __init__(self, ceiling): |     @classmethod | ||||||
|         self.ceiling = ceiling |     def pre_filter(self, queryset, user): | ||||||
|  |         ''' Returns all of the items from queryset where the date falls into | ||||||
|  |         any specified range, but not yet where the stock limit is not yet | ||||||
|  |         reached.''' | ||||||
| 
 | 
 | ||||||
|     def user_quantity_remaining(self, user): |  | ||||||
|         ''' returns 0 if the date range is violated, otherwise, it will return |  | ||||||
|         the quantity remaining under the stock limit. ''' |  | ||||||
| 
 |  | ||||||
|         # Test date range |  | ||||||
|         if not self._test_date_range(): |  | ||||||
|             return 0 |  | ||||||
| 
 |  | ||||||
|         return self._get_remaining_stock(user) |  | ||||||
| 
 |  | ||||||
|     def _test_date_range(self): |  | ||||||
|         now = timezone.now() |         now = timezone.now() | ||||||
| 
 | 
 | ||||||
|         if self.ceiling.start_time is not None: |         # Keep items with no start time, or start time not yet met. | ||||||
|             if now < self.ceiling.start_time: |         queryset = queryset.filter(Q(start_time=None) | Q(start_time__lte=now)) | ||||||
|                 return False |         queryset = queryset.filter(Q(end_time=None) | Q(end_time__gte=now)) | ||||||
| 
 | 
 | ||||||
|         if self.ceiling.end_time is not None: |         # Filter out items that have been reserved beyond the limits | ||||||
|             if now > self.ceiling.end_time: |         quantity_or_zero = self._calculate_quantities(user) | ||||||
|                 return False |  | ||||||
| 
 | 
 | ||||||
|         return True |         remainder = Case( | ||||||
|  |             When(limit=None, then=Value(_BIG_QUANTITY)), | ||||||
|  |             default=F("limit") - Sum(quantity_or_zero), | ||||||
|  |         ) | ||||||
| 
 | 
 | ||||||
|     def _get_remaining_stock(self, user): |         queryset = queryset.annotate(remainder=remainder) | ||||||
|         ''' Returns the stock that remains under this ceiling, excluding the |         queryset = queryset.filter(remainder__gt=0) | ||||||
|         user's current cart. ''' |  | ||||||
| 
 | 
 | ||||||
|         if self.ceiling.limit is None: |         return queryset | ||||||
|             return 99999999 |  | ||||||
| 
 | 
 | ||||||
|         # We care about all reserved carts, but not the user's current cart |     @classmethod | ||||||
|  |     def _relevant_carts(cls, user): | ||||||
|         reserved_carts = commerce.Cart.reserved_carts() |         reserved_carts = commerce.Cart.reserved_carts() | ||||||
|         reserved_carts = reserved_carts.exclude( |         reserved_carts = reserved_carts.exclude( | ||||||
|             user=user, |             user=user, | ||||||
|             status=commerce.Cart.STATUS_ACTIVE, |             status=commerce.Cart.STATUS_ACTIVE, | ||||||
|         ) |         ) | ||||||
| 
 |         return reserved_carts | ||||||
|         items = self._items() |  | ||||||
|         items = items.filter(cart__in=reserved_carts) |  | ||||||
|         count = items.aggregate(Sum("quantity"))["quantity__sum"] or 0 |  | ||||||
| 
 |  | ||||||
|         return self.ceiling.limit - count |  | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class TimeOrStockLimitFlagController( | class TimeOrStockLimitFlagController( | ||||||
|         TimeOrStockLimitConditionController): |         TimeOrStockLimitConditionController): | ||||||
| 
 | 
 | ||||||
|     def _items(self): |     @classmethod | ||||||
|         category_products = inventory.Product.objects.filter( |     def _calculate_quantities(cls, user): | ||||||
|             category__in=self.ceiling.categories.all(), |         reserved_carts = cls._relevant_carts(user) | ||||||
|         ) |  | ||||||
|         products = self.ceiling.products.all() | category_products |  | ||||||
| 
 | 
 | ||||||
|         product_items = commerce.ProductItem.objects.filter( |         # Calculate category lines | ||||||
|             product__in=products.all(), |         item_cats = F('categories__product__productitem__product__category') | ||||||
|  |         reserved_category_products = ( | ||||||
|  |             Q(categories=item_cats) & | ||||||
|  |             Q(categories__product__productitem__cart__in=reserved_carts) | ||||||
|         ) |         ) | ||||||
|         return product_items | 
 | ||||||
|  |         # Calculate product lines | ||||||
|  |         reserved_products = ( | ||||||
|  |             Q(products=F('products__productitem__product')) & | ||||||
|  |             Q(products__productitem__cart__in=reserved_carts) | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         category_quantity_in_reserved_carts = When( | ||||||
|  |             reserved_category_products, | ||||||
|  |             then="categories__product__productitem__quantity", | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         product_quantity_in_reserved_carts = When( | ||||||
|  |             reserved_products, | ||||||
|  |             then="products__productitem__quantity", | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         quantity_or_zero = Case( | ||||||
|  |             category_quantity_in_reserved_carts, | ||||||
|  |             product_quantity_in_reserved_carts, | ||||||
|  |             default=Value(0), | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         return quantity_or_zero | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class TimeOrStockLimitDiscountController(TimeOrStockLimitConditionController): | class TimeOrStockLimitDiscountController(TimeOrStockLimitConditionController): | ||||||
| 
 | 
 | ||||||
|     def _items(self): |     @classmethod | ||||||
|         discount_items = commerce.DiscountItem.objects.filter( |     def _calculate_quantities(cls, user): | ||||||
|             discount=self.ceiling, |         reserved_carts = cls._relevant_carts(user) | ||||||
|  | 
 | ||||||
|  |         quantity_in_reserved_carts = When( | ||||||
|  |             discountitem__cart__in=reserved_carts, | ||||||
|  |             then="discountitem__quantity" | ||||||
|         ) |         ) | ||||||
|         return discount_items | 
 | ||||||
|  |         quantity_or_zero = Case( | ||||||
|  |             quantity_in_reserved_carts, | ||||||
|  |             default=Value(0) | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         return quantity_or_zero | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class VoucherConditionController(ConditionController): | class VoucherConditionController(IsMetByFilter, ConditionController): | ||||||
|     ''' Condition test for VoucherFlag and VoucherDiscount.''' |     ''' Condition test for VoucherFlag and VoucherDiscount.''' | ||||||
| 
 | 
 | ||||||
|     def __init__(self, condition): |     @classmethod | ||||||
|         self.condition = condition |     def pre_filter(self, queryset, user): | ||||||
|  |         ''' Returns all of the items from queryset where the user has entered | ||||||
|  |         a voucher that invokes that item's condition in one of their carts. ''' | ||||||
| 
 | 
 | ||||||
|     def is_met(self, user): |         return queryset.filter(voucher__cart__user=user) | ||||||
|         ''' returns True if the user has the given voucher attached. ''' |  | ||||||
|         carts_count = commerce.Cart.objects.filter( |  | ||||||
|             user=user, |  | ||||||
|             vouchers=self.condition.voucher, |  | ||||||
|         ).count() |  | ||||||
|         return carts_count > 0 |  | ||||||
|  |  | ||||||
|  | @ -4,7 +4,11 @@ from conditions import ConditionController | ||||||
| from registrasion.models import commerce | from registrasion.models import commerce | ||||||
| from registrasion.models import conditions | from registrasion.models import conditions | ||||||
| 
 | 
 | ||||||
|  | from django.db.models import Case | ||||||
|  | from django.db.models import F, Q | ||||||
| from django.db.models import Sum | from django.db.models import Sum | ||||||
|  | from django.db.models import Value | ||||||
|  | from django.db.models import When | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class DiscountAndQuantity(object): | class DiscountAndQuantity(object): | ||||||
|  | @ -38,80 +42,158 @@ class DiscountAndQuantity(object): | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def available_discounts(user, categories, products): | class DiscountController(object): | ||||||
|     ''' Returns all discounts available to this user for the given categories |  | ||||||
|     and products. The discounts also list the available quantity for this user, |  | ||||||
|     not including products that are pending purchase. ''' |  | ||||||
| 
 | 
 | ||||||
|     # discounts that match provided categories |     @classmethod | ||||||
|     category_discounts = conditions.DiscountForCategory.objects.filter( |     def available_discounts(cls, user, categories, products): | ||||||
|         category__in=categories |         ''' Returns all discounts available to this user for the given | ||||||
|     ) |         categories and products. The discounts also list the available quantity | ||||||
|     # discounts that match provided products |         for this user, not including products that are pending purchase. ''' | ||||||
|     product_discounts = conditions.DiscountForProduct.objects.filter( |  | ||||||
|         product__in=products |  | ||||||
|     ) |  | ||||||
|     # discounts that match categories for provided products |  | ||||||
|     product_category_discounts = conditions.DiscountForCategory.objects.filter( |  | ||||||
|         category__in=(product.category for product in products) |  | ||||||
|     ) |  | ||||||
|     # (Not relevant: discounts that match products in provided categories) |  | ||||||
| 
 | 
 | ||||||
|     product_discounts = product_discounts.select_related( |         filtered_clauses = cls._filtered_discounts(user, categories, products) | ||||||
|         "product", |  | ||||||
|         "product__category", |  | ||||||
|     ) |  | ||||||
| 
 | 
 | ||||||
|     all_category_discounts = category_discounts | product_category_discounts |         discounts = [] | ||||||
|     all_category_discounts = all_category_discounts.select_related( |  | ||||||
|         "category", |  | ||||||
|     ) |  | ||||||
| 
 | 
 | ||||||
|     # The set of all potential discounts |         # Markers so that we don't need to evaluate given conditions | ||||||
|     potential_discounts = set(itertools.chain( |         # more than once | ||||||
|         product_discounts, |         accepted_discounts = set() | ||||||
|         all_category_discounts, |         failed_discounts = set() | ||||||
|     )) |  | ||||||
| 
 | 
 | ||||||
|     discounts = [] |         for clause in filtered_clauses: | ||||||
|  |             discount = clause.discount | ||||||
|  |             cond = ConditionController.for_condition(discount) | ||||||
| 
 | 
 | ||||||
|     # Markers so that we don't need to evaluate given conditions more than once |             past_use_count = clause.past_use_count | ||||||
|     accepted_discounts = set() |             if past_use_count >= clause.quantity: | ||||||
|     failed_discounts = set() |                 # This clause has exceeded its use count | ||||||
|  |                 pass | ||||||
|  |             elif discount not in failed_discounts: | ||||||
|  |                 # This clause is still available | ||||||
|  |                 is_accepted = discount in accepted_discounts | ||||||
|  |                 if is_accepted or cond.is_met(user, filtered=True): | ||||||
|  |                     # This clause is valid for this user | ||||||
|  |                     discounts.append(DiscountAndQuantity( | ||||||
|  |                         discount=discount, | ||||||
|  |                         clause=clause, | ||||||
|  |                         quantity=clause.quantity - past_use_count, | ||||||
|  |                     )) | ||||||
|  |                     accepted_discounts.add(discount) | ||||||
|  |                 else: | ||||||
|  |                     # This clause is not valid for this user | ||||||
|  |                     failed_discounts.add(discount) | ||||||
|  |         return discounts | ||||||
| 
 | 
 | ||||||
|     for discount in potential_discounts: |     @classmethod | ||||||
|         real_discount = conditions.DiscountBase.objects.get_subclass( |     def _filtered_discounts(cls, user, categories, products): | ||||||
|             pk=discount.discount.pk, |         ''' | ||||||
|  | 
 | ||||||
|  |         Returns: | ||||||
|  |             Sequence[discountbase]: All discounts that passed the filter | ||||||
|  |             function. | ||||||
|  | 
 | ||||||
|  |         ''' | ||||||
|  | 
 | ||||||
|  |         types = list(ConditionController._controllers()) | ||||||
|  |         discounttypes = [ | ||||||
|  |             i for i in types if issubclass(i, conditions.DiscountBase) | ||||||
|  |         ] | ||||||
|  | 
 | ||||||
|  |         # discounts that match provided categories | ||||||
|  |         category_discounts = conditions.DiscountForCategory.objects.filter( | ||||||
|  |             category__in=categories | ||||||
|         ) |         ) | ||||||
|         cond = ConditionController.for_condition(real_discount) |         # discounts that match provided products | ||||||
| 
 |         product_discounts = conditions.DiscountForProduct.objects.filter( | ||||||
|         # Count the past uses of the given discount item. |             product__in=products | ||||||
|         # If this user has exceeded the limit for the clause, this clause |  | ||||||
|         # is not available any more. |  | ||||||
|         past_uses = commerce.DiscountItem.objects.filter( |  | ||||||
|             cart__user=user, |  | ||||||
|             cart__status=commerce.Cart.STATUS_PAID,  # Only past carts count |  | ||||||
|             discount=real_discount, |  | ||||||
|         ) |         ) | ||||||
|         agg = past_uses.aggregate(Sum("quantity")) |         # discounts that match categories for provided products | ||||||
|         past_use_count = agg["quantity__sum"] |         product_category_discounts = conditions.DiscountForCategory.objects | ||||||
|         if past_use_count is None: |         product_category_discounts = product_category_discounts.filter( | ||||||
|             past_use_count = 0 |             category__in=(product.category for product in products) | ||||||
|  |         ) | ||||||
|  |         # (Not relevant: discounts that match products in provided categories) | ||||||
| 
 | 
 | ||||||
|         if past_use_count >= discount.quantity: |         product_discounts = product_discounts.select_related( | ||||||
|             # This clause has exceeded its use count |             "product", | ||||||
|             pass |             "product__category", | ||||||
|         elif real_discount not in failed_discounts: |         ) | ||||||
|             # This clause is still available | 
 | ||||||
|             if real_discount in accepted_discounts or cond.is_met(user): |         all_category_discounts = ( | ||||||
|                 # This clause is valid for this user |             category_discounts | product_category_discounts | ||||||
|                 discounts.append(DiscountAndQuantity( |         ) | ||||||
|                     discount=real_discount, |         all_category_discounts = all_category_discounts.select_related( | ||||||
|                     clause=discount, |             "category", | ||||||
|                     quantity=discount.quantity - past_use_count, |         ) | ||||||
|                 )) | 
 | ||||||
|                 accepted_discounts.add(real_discount) |         valid_discounts = conditions.DiscountBase.objects.filter( | ||||||
|             else: |             Q(discountforproduct__in=product_discounts) | | ||||||
|                 # This clause is not valid for this user |             Q(discountforcategory__in=all_category_discounts) | ||||||
|                 failed_discounts.add(real_discount) |         ) | ||||||
|     return discounts | 
 | ||||||
|  |         all_subsets = [] | ||||||
|  | 
 | ||||||
|  |         for discounttype in discounttypes: | ||||||
|  |             discounts = discounttype.objects.filter(id__in=valid_discounts) | ||||||
|  |             ctrl = ConditionController.for_type(discounttype) | ||||||
|  |             discounts = ctrl.pre_filter(discounts, user) | ||||||
|  |             all_subsets.append(discounts) | ||||||
|  | 
 | ||||||
|  |         filtered_discounts = list(itertools.chain(*all_subsets)) | ||||||
|  | 
 | ||||||
|  |         # Map from discount key to itself | ||||||
|  |         # (contains annotations needed in the future) | ||||||
|  |         from_filter = dict((i.id, i) for i in filtered_discounts) | ||||||
|  | 
 | ||||||
|  |         clause_sets = ( | ||||||
|  |             product_discounts.filter(discount__in=filtered_discounts), | ||||||
|  |             all_category_discounts.filter(discount__in=filtered_discounts), | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         clause_sets = ( | ||||||
|  |             cls._annotate_with_past_uses(i, user) for i in clause_sets | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         # The set of all potential discount clauses | ||||||
|  |         discount_clauses = set(itertools.chain(*clause_sets)) | ||||||
|  | 
 | ||||||
|  |         # Replace discounts with the filtered ones | ||||||
|  |         # These are the correct subclasses (saves query later on), and have | ||||||
|  |         # correct annotations from filters if necessary. | ||||||
|  |         for clause in discount_clauses: | ||||||
|  |             clause.discount = from_filter[clause.discount.id] | ||||||
|  | 
 | ||||||
|  |         return discount_clauses | ||||||
|  | 
 | ||||||
|  |     @classmethod | ||||||
|  |     def _annotate_with_past_uses(cls, queryset, user): | ||||||
|  |         ''' Annotates the queryset with a usage count for that discount claus | ||||||
|  |         by the given user. ''' | ||||||
|  | 
 | ||||||
|  |         if queryset.model == conditions.DiscountForCategory: | ||||||
|  |             matches = ( | ||||||
|  |                 Q(category=F('discount__discountitem__product__category')) | ||||||
|  |             ) | ||||||
|  |         elif queryset.model == conditions.DiscountForProduct: | ||||||
|  |             matches = ( | ||||||
|  |                 Q(product=F('discount__discountitem__product')) | ||||||
|  |             ) | ||||||
|  | 
 | ||||||
|  |         in_carts = ( | ||||||
|  |             Q(discount__discountitem__cart__user=user) & | ||||||
|  |             Q(discount__discountitem__cart__status=commerce.Cart.STATUS_PAID) | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         past_use_quantity = When( | ||||||
|  |             in_carts & matches, | ||||||
|  |             then="discount__discountitem__quantity", | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         past_use_quantity_or_zero = Case( | ||||||
|  |             past_use_quantity, | ||||||
|  |             default=Value(0), | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         queryset = queryset.annotate( | ||||||
|  |             past_use_count=Sum(past_use_quantity_or_zero) | ||||||
|  |         ) | ||||||
|  |         return queryset | ||||||
|  |  | ||||||
							
								
								
									
										264
									
								
								registrasion/controllers/flag.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										264
									
								
								registrasion/controllers/flag.py
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,264 @@ | ||||||
|  | import itertools | ||||||
|  | import operator | ||||||
|  | 
 | ||||||
|  | from collections import defaultdict | ||||||
|  | from collections import namedtuple | ||||||
|  | from django.db.models import Count | ||||||
|  | from django.db.models import Q | ||||||
|  | 
 | ||||||
|  | from .conditions import ConditionController | ||||||
|  | 
 | ||||||
|  | from registrasion.models import conditions | ||||||
|  | from registrasion.models import inventory | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class FlagController(object): | ||||||
|  | 
 | ||||||
|  |     SINGLE = True | ||||||
|  |     PLURAL = False | ||||||
|  |     NONE = True | ||||||
|  |     SOME = False | ||||||
|  |     MESSAGE = { | ||||||
|  |         NONE: { | ||||||
|  |             SINGLE: | ||||||
|  |                 "%(items)s is no longer available to you", | ||||||
|  |             PLURAL: | ||||||
|  |                 "%(items)s are no longer available to you", | ||||||
|  |         }, | ||||||
|  |         SOME: { | ||||||
|  |             SINGLE: | ||||||
|  |                 "Only %(remainder)d of the following item remains: %(items)s", | ||||||
|  |             PLURAL: | ||||||
|  |                 "Only %(remainder)d of the following items remain: %(items)s" | ||||||
|  |         }, | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     @classmethod | ||||||
|  |     def test_flags( | ||||||
|  |             cls, user, products=None, product_quantities=None): | ||||||
|  |         ''' Evaluates all of the flag conditions on the given products. | ||||||
|  | 
 | ||||||
|  |         If `product_quantities` is supplied, the condition is only met if it | ||||||
|  |         will permit the sum of the product quantities for all of the products | ||||||
|  |         it covers. Otherwise, it will be met if at least one item can be | ||||||
|  |         accepted. | ||||||
|  | 
 | ||||||
|  |         If all flag conditions pass, an empty list is returned, otherwise | ||||||
|  |         a list is returned containing all of the products that are *not | ||||||
|  |         enabled*. ''' | ||||||
|  | 
 | ||||||
|  |         print "GREPME: test_flags()" | ||||||
|  | 
 | ||||||
|  |         if products is not None and product_quantities is not None: | ||||||
|  |             raise ValueError("Please specify only products or " | ||||||
|  |                              "product_quantities") | ||||||
|  |         elif products is None: | ||||||
|  |             products = set(i[0] for i in product_quantities) | ||||||
|  |             quantities = dict((product, quantity) | ||||||
|  |                               for product, quantity in product_quantities) | ||||||
|  |         elif product_quantities is None: | ||||||
|  |             products = set(products) | ||||||
|  |             quantities = {} | ||||||
|  | 
 | ||||||
|  |         if products: | ||||||
|  |             # Simplify the query. | ||||||
|  |             all_conditions = cls._filtered_flags(user, products) | ||||||
|  |         else: | ||||||
|  |             all_conditions = [] | ||||||
|  | 
 | ||||||
|  |         # All disable-if-false conditions on a product need to be met | ||||||
|  |         do_not_disable = defaultdict(lambda: True) | ||||||
|  |         # At least one enable-if-true condition on a product must be met | ||||||
|  |         do_enable = defaultdict(lambda: False) | ||||||
|  |         # (if either sort of condition is present) | ||||||
|  | 
 | ||||||
|  |         # Count the number of conditions for a product | ||||||
|  |         dif_count = defaultdict(int) | ||||||
|  |         eit_count = defaultdict(int) | ||||||
|  | 
 | ||||||
|  |         messages = {} | ||||||
|  | 
 | ||||||
|  |         for condition in all_conditions: | ||||||
|  |             cond = ConditionController.for_condition(condition) | ||||||
|  |             remainder = cond.user_quantity_remaining(user, filtered=True) | ||||||
|  | 
 | ||||||
|  |             # Get all products covered by this condition, and the products | ||||||
|  |             # from the categories covered by this condition | ||||||
|  | 
 | ||||||
|  |             ids = [product.id for product in products] | ||||||
|  |             all_products = inventory.Product.objects.filter(id__in=ids) | ||||||
|  |             cond = ( | ||||||
|  |                 Q(flagbase_set=condition) | | ||||||
|  |                 Q(category__in=condition.categories.all()) | ||||||
|  |             ) | ||||||
|  | 
 | ||||||
|  |             all_products = all_products.filter(cond) | ||||||
|  |             all_products = all_products.select_related("category") | ||||||
|  | 
 | ||||||
|  |             if quantities: | ||||||
|  |                 consumed = sum(quantities[i] for i in all_products) | ||||||
|  |             else: | ||||||
|  |                 consumed = 1 | ||||||
|  |             met = consumed <= remainder | ||||||
|  | 
 | ||||||
|  |             if not met: | ||||||
|  |                 items = ", ".join(str(product) for product in all_products) | ||||||
|  |                 base = cls.MESSAGE[remainder == 0][len(all_products) == 1] | ||||||
|  |                 message = base % {"items": items, "remainder": remainder} | ||||||
|  | 
 | ||||||
|  |             for product in all_products: | ||||||
|  |                 if condition.is_disable_if_false: | ||||||
|  |                     do_not_disable[product] &= met | ||||||
|  |                     dif_count[product] += 1 | ||||||
|  |                 else: | ||||||
|  |                     do_enable[product] |= met | ||||||
|  |                     eit_count[product] += 1 | ||||||
|  | 
 | ||||||
|  |                 if not met and product not in messages: | ||||||
|  |                     messages[product] = message | ||||||
|  | 
 | ||||||
|  |         total_flags = FlagCounter.count() | ||||||
|  | 
 | ||||||
|  |         valid = {} | ||||||
|  | 
 | ||||||
|  |         # the problem is that now, not every condition falls into | ||||||
|  |         # do_not_disable or do_enable ''' | ||||||
|  |         # You should look into this, chris :) | ||||||
|  | 
 | ||||||
|  |         for product in products: | ||||||
|  |             if quantities: | ||||||
|  |                 if quantities[product] == 0: | ||||||
|  |                     continue | ||||||
|  | 
 | ||||||
|  |             f = total_flags.get(product) | ||||||
|  |             if f.dif > 0 and f.dif != dif_count[product]: | ||||||
|  |                 do_not_disable[product] = False | ||||||
|  |                 if product not in messages: | ||||||
|  |                     messages[product] = "Some disable-if-false " \ | ||||||
|  |                                         "conditions were not met" | ||||||
|  |             if f.eit > 0 and product not in do_enable: | ||||||
|  |                 do_enable[product] = False | ||||||
|  |                 if product not in messages: | ||||||
|  |                     messages[product] = "Some enable-if-true " \ | ||||||
|  |                                         "conditions were not met" | ||||||
|  | 
 | ||||||
|  |         for product in itertools.chain(do_not_disable, do_enable): | ||||||
|  |             f = total_flags.get(product) | ||||||
|  |             if product in do_enable: | ||||||
|  |                 # If there's an enable-if-true, we need need of those met too. | ||||||
|  |                 # (do_not_disable will default to true otherwise) | ||||||
|  |                 valid[product] = do_not_disable[product] and do_enable[product] | ||||||
|  |             elif product in do_not_disable: | ||||||
|  |                 # If there's a disable-if-false condition, all must be met | ||||||
|  |                 valid[product] = do_not_disable[product] | ||||||
|  | 
 | ||||||
|  |         error_fields = [ | ||||||
|  |             (product, messages[product]) | ||||||
|  |             for product in valid if not valid[product] | ||||||
|  |         ] | ||||||
|  | 
 | ||||||
|  |         return error_fields | ||||||
|  | 
 | ||||||
|  |     @classmethod | ||||||
|  |     def _filtered_flags(cls, user, products): | ||||||
|  |         ''' | ||||||
|  | 
 | ||||||
|  |         Returns: | ||||||
|  |             Sequence[flagbase]: All flags that passed the filter function. | ||||||
|  | 
 | ||||||
|  |         ''' | ||||||
|  | 
 | ||||||
|  |         types = list(ConditionController._controllers()) | ||||||
|  |         flagtypes = [i for i in types if issubclass(i, conditions.FlagBase)] | ||||||
|  | 
 | ||||||
|  |         # Get all flags for the products and categories. | ||||||
|  |         prods = ( | ||||||
|  |             product.flagbase_set.all() | ||||||
|  |             for product in products | ||||||
|  |         ) | ||||||
|  |         cats = ( | ||||||
|  |             category.flagbase_set.all() | ||||||
|  |             for category in set(product.category for product in products) | ||||||
|  |         ) | ||||||
|  |         all_flags = reduce(operator.or_, itertools.chain(prods, cats)) | ||||||
|  | 
 | ||||||
|  |         all_subsets = [] | ||||||
|  | 
 | ||||||
|  |         for flagtype in flagtypes: | ||||||
|  |             flags = flagtype.objects.filter(id__in=all_flags) | ||||||
|  |             ctrl = ConditionController.for_type(flagtype) | ||||||
|  |             flags = ctrl.pre_filter(flags, user) | ||||||
|  |             all_subsets.append(flags) | ||||||
|  | 
 | ||||||
|  |         return itertools.chain(*all_subsets) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | ConditionAndRemainder = namedtuple( | ||||||
|  |     "ConditionAndRemainder", | ||||||
|  |     ( | ||||||
|  |         "condition", | ||||||
|  |         "remainder", | ||||||
|  |     ), | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | _FlagCounter = namedtuple( | ||||||
|  |     "_FlagCounter", | ||||||
|  |     ( | ||||||
|  |         "products", | ||||||
|  |         "categories", | ||||||
|  |     ), | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | _ConditionsCount = namedtuple( | ||||||
|  |     "ConditionsCount", | ||||||
|  |     ( | ||||||
|  |         "dif", | ||||||
|  |         "eit", | ||||||
|  |     ), | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | # TODO: this should be cacheable. | ||||||
|  | class FlagCounter(_FlagCounter): | ||||||
|  | 
 | ||||||
|  |     @classmethod | ||||||
|  |     def count(cls): | ||||||
|  |         # Get the count of how many conditions should exist per product | ||||||
|  |         flagbases = conditions.FlagBase.objects | ||||||
|  | 
 | ||||||
|  |         types = ( | ||||||
|  |             conditions.FlagBase.ENABLE_IF_TRUE, | ||||||
|  |             conditions.FlagBase.DISABLE_IF_FALSE, | ||||||
|  |         ) | ||||||
|  |         keys = ("eit", "dif") | ||||||
|  |         flags = [ | ||||||
|  |             flagbases.filter( | ||||||
|  |                 condition=condition_type | ||||||
|  |             ).values( | ||||||
|  |                 'products', 'categories' | ||||||
|  |             ).annotate( | ||||||
|  |                 count=Count('id') | ||||||
|  |             ) | ||||||
|  |             for condition_type in types | ||||||
|  |         ] | ||||||
|  | 
 | ||||||
|  |         cats = defaultdict(lambda: defaultdict(int)) | ||||||
|  |         prod = defaultdict(lambda: defaultdict(int)) | ||||||
|  | 
 | ||||||
|  |         for key, flagcounts in zip(keys, flags): | ||||||
|  |             for row in flagcounts: | ||||||
|  |                 if row["products"] is not None: | ||||||
|  |                     prod[row["products"]][key] = row["count"] | ||||||
|  |                 if row["categories"] is not None: | ||||||
|  |                     cats[row["categories"]][key] = row["count"] | ||||||
|  | 
 | ||||||
|  |         return cls(products=prod, categories=cats) | ||||||
|  | 
 | ||||||
|  |     def get(self, product): | ||||||
|  |         p = self.products[product.id] | ||||||
|  |         c = self.categories[product.category.id] | ||||||
|  |         eit = p["eit"] + c["eit"] | ||||||
|  |         dif = p["dif"] + c["dif"] | ||||||
|  |         return _ConditionsCount(dif=dif, eit=eit) | ||||||
|  | @ -29,6 +29,7 @@ class InvoiceController(ForId, object): | ||||||
|         If such an invoice does not exist, the cart is validated, and if valid, |         If such an invoice does not exist, the cart is validated, and if valid, | ||||||
|         an invoice is generated.''' |         an invoice is generated.''' | ||||||
| 
 | 
 | ||||||
|  |         cart.refresh_from_db() | ||||||
|         try: |         try: | ||||||
|             invoice = commerce.Invoice.objects.exclude( |             invoice = commerce.Invoice.objects.exclude( | ||||||
|                 status=commerce.Invoice.STATUS_VOID, |                 status=commerce.Invoice.STATUS_VOID, | ||||||
|  | @ -74,6 +75,8 @@ class InvoiceController(ForId, object): | ||||||
|     def _generate(cls, cart): |     def _generate(cls, cart): | ||||||
|         ''' Generates an invoice for the given cart. ''' |         ''' Generates an invoice for the given cart. ''' | ||||||
| 
 | 
 | ||||||
|  |         cart.refresh_from_db() | ||||||
|  | 
 | ||||||
|         issued = timezone.now() |         issued = timezone.now() | ||||||
|         reservation_limit = cart.reservation_duration + cart.time_last_updated |         reservation_limit = cart.reservation_duration + cart.time_last_updated | ||||||
|         # Never generate a due time that is before the issue time |         # Never generate a due time that is before the issue time | ||||||
|  | @ -96,6 +99,10 @@ class InvoiceController(ForId, object): | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|         product_items = commerce.ProductItem.objects.filter(cart=cart) |         product_items = commerce.ProductItem.objects.filter(cart=cart) | ||||||
|  |         product_items = product_items.select_related( | ||||||
|  |             "product", | ||||||
|  |             "product__category", | ||||||
|  |         ) | ||||||
| 
 | 
 | ||||||
|         if len(product_items) == 0: |         if len(product_items) == 0: | ||||||
|             raise ValidationError("Your cart is empty.") |             raise ValidationError("Your cart is empty.") | ||||||
|  | @ -103,29 +110,41 @@ class InvoiceController(ForId, object): | ||||||
|         product_items = product_items.order_by( |         product_items = product_items.order_by( | ||||||
|             "product__category__order", "product__order" |             "product__category__order", "product__order" | ||||||
|         ) |         ) | ||||||
|  | 
 | ||||||
|         discount_items = commerce.DiscountItem.objects.filter(cart=cart) |         discount_items = commerce.DiscountItem.objects.filter(cart=cart) | ||||||
|  |         discount_items = discount_items.select_related( | ||||||
|  |             "discount", | ||||||
|  |             "product", | ||||||
|  |             "product__category", | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         line_items = [] | ||||||
|  | 
 | ||||||
|         invoice_value = Decimal() |         invoice_value = Decimal() | ||||||
|         for item in product_items: |         for item in product_items: | ||||||
|             product = item.product |             product = item.product | ||||||
|             line_item = commerce.LineItem.objects.create( |             line_item = commerce.LineItem( | ||||||
|                 invoice=invoice, |                 invoice=invoice, | ||||||
|                 description="%s - %s" % (product.category.name, product.name), |                 description="%s - %s" % (product.category.name, product.name), | ||||||
|                 quantity=item.quantity, |                 quantity=item.quantity, | ||||||
|                 price=product.price, |                 price=product.price, | ||||||
|                 product=product, |                 product=product, | ||||||
|             ) |             ) | ||||||
|  |             line_items.append(line_item) | ||||||
|             invoice_value += line_item.quantity * line_item.price |             invoice_value += line_item.quantity * line_item.price | ||||||
| 
 |  | ||||||
|         for item in discount_items: |         for item in discount_items: | ||||||
|             line_item = commerce.LineItem.objects.create( |             line_item = commerce.LineItem( | ||||||
|                 invoice=invoice, |                 invoice=invoice, | ||||||
|                 description=item.discount.description, |                 description=item.discount.description, | ||||||
|                 quantity=item.quantity, |                 quantity=item.quantity, | ||||||
|                 price=cls.resolve_discount_value(item) * -1, |                 price=cls.resolve_discount_value(item) * -1, | ||||||
|                 product=item.product, |                 product=item.product, | ||||||
|             ) |             ) | ||||||
|  |             line_items.append(line_item) | ||||||
|             invoice_value += line_item.quantity * line_item.price |             invoice_value += line_item.quantity * line_item.price | ||||||
| 
 | 
 | ||||||
|  |         commerce.LineItem.objects.bulk_create(line_items) | ||||||
|  | 
 | ||||||
|         invoice.value = invoice_value |         invoice.value = invoice_value | ||||||
| 
 | 
 | ||||||
|         invoice.save() |         invoice.save() | ||||||
|  | @ -251,6 +270,9 @@ class InvoiceController(ForId, object): | ||||||
|     def _invoice_matches_cart(self): |     def _invoice_matches_cart(self): | ||||||
|         ''' Returns true if there is no cart, or if the revision of this |         ''' Returns true if there is no cart, or if the revision of this | ||||||
|         invoice matches the current revision of the cart. ''' |         invoice matches the current revision of the cart. ''' | ||||||
|  | 
 | ||||||
|  |         self._refresh() | ||||||
|  | 
 | ||||||
|         cart = self.invoice.cart |         cart = self.invoice.cart | ||||||
|         if not cart: |         if not cart: | ||||||
|             return True |             return True | ||||||
|  |  | ||||||
|  | @ -1,11 +1,16 @@ | ||||||
| import itertools | import itertools | ||||||
| 
 | 
 | ||||||
|  | from django.db.models import Case | ||||||
|  | from django.db.models import F, Q | ||||||
| from django.db.models import Sum | from django.db.models import Sum | ||||||
|  | from django.db.models import When | ||||||
|  | from django.db.models import Value | ||||||
|  | 
 | ||||||
| from registrasion.models import commerce | from registrasion.models import commerce | ||||||
| from registrasion.models import inventory | from registrasion.models import inventory | ||||||
| 
 | 
 | ||||||
| from category import CategoryController | from .category import CategoryController | ||||||
| from conditions import ConditionController | from .flag import FlagController | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class ProductController(object): | class ProductController(object): | ||||||
|  | @ -16,9 +21,7 @@ class ProductController(object): | ||||||
|     @classmethod |     @classmethod | ||||||
|     def available_products(cls, user, category=None, products=None): |     def available_products(cls, user, category=None, products=None): | ||||||
|         ''' Returns a list of all of the products that are available per |         ''' Returns a list of all of the products that are available per | ||||||
|         flag conditions from the given categories. |         flag conditions from the given categories. ''' | ||||||
|         TODO: refactor so that all conditions are tested here and |  | ||||||
|         can_add_with_flags calls this method. ''' |  | ||||||
|         if category is None and products is None: |         if category is None and products is None: | ||||||
|             raise ValueError("You must provide products or a category") |             raise ValueError("You must provide products or a category") | ||||||
| 
 | 
 | ||||||
|  | @ -31,22 +34,21 @@ class ProductController(object): | ||||||
|         if products is not None: |         if products is not None: | ||||||
|             all_products = set(itertools.chain(all_products, products)) |             all_products = set(itertools.chain(all_products, products)) | ||||||
| 
 | 
 | ||||||
|         cat_quants = dict( |         categories = set(product.category for product in all_products) | ||||||
|             ( |         r = CategoryController.attach_user_remainders(user, categories) | ||||||
|                 category, |         cat_quants = dict((c, c) for c in r) | ||||||
|                 CategoryController(category).user_quantity_remaining(user), | 
 | ||||||
|             ) |         r = ProductController.attach_user_remainders(user, all_products) | ||||||
|             for category in set(product.category for product in all_products) |         prod_quants = dict((p, p) for p in r) | ||||||
|         ) |  | ||||||
| 
 | 
 | ||||||
|         passed_limits = set( |         passed_limits = set( | ||||||
|             product |             product | ||||||
|             for product in all_products |             for product in all_products | ||||||
|             if cat_quants[product.category] > 0 |             if cat_quants[product.category].remainder > 0 | ||||||
|             if cls(product).user_quantity_remaining(user) > 0 |             if prod_quants[product].remainder > 0 | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|         failed_and_messages = ConditionController.test_flags( |         failed_and_messages = FlagController.test_flags( | ||||||
|             user, products=passed_limits |             user, products=passed_limits | ||||||
|         ) |         ) | ||||||
|         failed_conditions = set(i[0] for i in failed_and_messages) |         failed_conditions = set(i[0] for i in failed_and_messages) | ||||||
|  | @ -56,26 +58,47 @@ class ProductController(object): | ||||||
| 
 | 
 | ||||||
|         return out |         return out | ||||||
| 
 | 
 | ||||||
|  |     @classmethod | ||||||
|  |     def attach_user_remainders(cls, user, products): | ||||||
|  |         ''' | ||||||
|  | 
 | ||||||
|  |         Return: | ||||||
|  |             queryset(inventory.Product): A queryset containing items from | ||||||
|  |             ``product``, with an extra attribute -- remainder = the amount of | ||||||
|  |             this item that is remaining. | ||||||
|  |         ''' | ||||||
|  | 
 | ||||||
|  |         ids = [product.id for product in products] | ||||||
|  |         products = inventory.Product.objects.filter(id__in=ids) | ||||||
|  | 
 | ||||||
|  |         cart_filter = ( | ||||||
|  |             Q(productitem__cart__user=user) & | ||||||
|  |             Q(productitem__cart__status=commerce.Cart.STATUS_PAID) | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         quantity = When( | ||||||
|  |             cart_filter, | ||||||
|  |             then='productitem__quantity' | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         quantity_or_zero = Case( | ||||||
|  |             quantity, | ||||||
|  |             default=Value(0), | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         remainder = Case( | ||||||
|  |             When(limit_per_user=None, then=Value(99999999)), | ||||||
|  |             default=F('limit_per_user') - Sum(quantity_or_zero), | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         products = products.annotate(remainder=remainder) | ||||||
|  | 
 | ||||||
|  |         return products | ||||||
|  | 
 | ||||||
|     def user_quantity_remaining(self, user): |     def user_quantity_remaining(self, user): | ||||||
|         ''' Returns the quantity of this product that the user add in the |         ''' Returns the quantity of this product that the user add in the | ||||||
|         current cart. ''' |         current cart. ''' | ||||||
| 
 | 
 | ||||||
|         prod_limit = self.product.limit_per_user |         with_remainders = self.attach_user_remainders(user, [self.product]) | ||||||
| 
 | 
 | ||||||
|         if prod_limit is None: |         return with_remainders[0].remainder | ||||||
|             # Don't need to run the remaining queries |  | ||||||
|             return 999999  # We can do better |  | ||||||
| 
 |  | ||||||
|         carts = commerce.Cart.objects.filter( |  | ||||||
|             user=user, |  | ||||||
|             status=commerce.Cart.STATUS_PAID, |  | ||||||
|         ) |  | ||||||
| 
 |  | ||||||
|         items = commerce.ProductItem.objects.filter( |  | ||||||
|             cart__in=carts, |  | ||||||
|             product=self.product, |  | ||||||
|         ) |  | ||||||
| 
 |  | ||||||
|         prod_count = items.aggregate(Sum("quantity"))["quantity__sum"] or 0 |  | ||||||
| 
 |  | ||||||
|         return prod_limit - prod_count |  | ||||||
|  |  | ||||||
|  | @ -4,7 +4,11 @@ from registrasion.controllers.category import CategoryController | ||||||
| 
 | 
 | ||||||
| from collections import namedtuple | from collections import namedtuple | ||||||
| from django import template | from django import template | ||||||
|  | from django.db.models import Case | ||||||
|  | from django.db.models import Q | ||||||
| from django.db.models import Sum | from django.db.models import Sum | ||||||
|  | from django.db.models import When | ||||||
|  | from django.db.models import Value | ||||||
| 
 | 
 | ||||||
| register = template.Library() | register = template.Library() | ||||||
| 
 | 
 | ||||||
|  | @ -99,20 +103,33 @@ def items_purchased(context, category=None): | ||||||
| 
 | 
 | ||||||
|     ''' |     ''' | ||||||
| 
 | 
 | ||||||
|     all_items = commerce.ProductItem.objects.filter( |     in_cart = ( | ||||||
|         cart__user=context.request.user, |         Q(productitem__cart__user=context.request.user) & | ||||||
|         cart__status=commerce.Cart.STATUS_PAID, |         Q(productitem__cart__status=commerce.Cart.STATUS_PAID) | ||||||
|     ).select_related("product", "product__category") |     ) | ||||||
|  | 
 | ||||||
|  |     quantities_in_cart = When( | ||||||
|  |         in_cart, | ||||||
|  |         then="productitem__quantity", | ||||||
|  |     ) | ||||||
|  | 
 | ||||||
|  |     quantities_or_zero = Case( | ||||||
|  |         quantities_in_cart, | ||||||
|  |         default=Value(0), | ||||||
|  |     ) | ||||||
|  | 
 | ||||||
|  |     products = inventory.Product.objects | ||||||
| 
 | 
 | ||||||
|     if category: |     if category: | ||||||
|         all_items = all_items.filter(product__category=category) |         products = products.filter(category=category) | ||||||
|  | 
 | ||||||
|  |     products = products.select_related("category") | ||||||
|  |     products = products.annotate(quantity=Sum(quantities_or_zero)) | ||||||
|  |     products = products.filter(quantity__gt=0) | ||||||
| 
 | 
 | ||||||
|     pq = all_items.values("product").annotate(quantity=Sum("quantity")).all() |  | ||||||
|     products = inventory.Product.objects.all() |  | ||||||
|     out = [] |     out = [] | ||||||
|     for item in pq: |     for prod in products: | ||||||
|         prod = products.get(pk=item["product"]) |         out.append(ProductAndQuantity(prod, prod.quantity)) | ||||||
|         out.append(ProductAndQuantity(prod, item["quantity"])) |  | ||||||
|     return out |     return out | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -26,7 +26,7 @@ class RegistrationCartTestCase(SetTimeMixin, TestCase): | ||||||
|         super(RegistrationCartTestCase, self).setUp() |         super(RegistrationCartTestCase, self).setUp() | ||||||
| 
 | 
 | ||||||
|     def tearDown(self): |     def tearDown(self): | ||||||
|         if False: |         if True: | ||||||
|             # If you're seeing segfaults in tests, enable this. |             # If you're seeing segfaults in tests, enable this. | ||||||
|             call_command( |             call_command( | ||||||
|                 'flush', |                 'flush', | ||||||
|  |  | ||||||
|  | @ -6,6 +6,8 @@ from django.core.exceptions import ValidationError | ||||||
| from controller_helpers import TestingCartController | from controller_helpers import TestingCartController | ||||||
| from test_cart import RegistrationCartTestCase | from test_cart import RegistrationCartTestCase | ||||||
| 
 | 
 | ||||||
|  | from registrasion.controllers.discount import DiscountController | ||||||
|  | from registrasion.controllers.product import ProductController | ||||||
| from registrasion.models import commerce | from registrasion.models import commerce | ||||||
| from registrasion.models import conditions | from registrasion.models import conditions | ||||||
| 
 | 
 | ||||||
|  | @ -135,6 +137,43 @@ class CeilingsTestCases(RegistrationCartTestCase): | ||||||
|         with self.assertRaises(ValidationError): |         with self.assertRaises(ValidationError): | ||||||
|             first_cart.validate_cart() |             first_cart.validate_cart() | ||||||
| 
 | 
 | ||||||
|  |     def test_discount_ceiling_aggregates_products(self): | ||||||
|  |         # Create two carts, add 1xprod_1 to each. Ceiling should disappear | ||||||
|  |         # after second. | ||||||
|  |         self.make_discount_ceiling( | ||||||
|  |             "Multi-product limit discount ceiling", | ||||||
|  |             limit=2, | ||||||
|  |         ) | ||||||
|  |         for i in xrange(2): | ||||||
|  |             cart = TestingCartController.for_user(self.USER_1) | ||||||
|  |             cart.add_to_cart(self.PROD_1, 1) | ||||||
|  |             cart.next_cart() | ||||||
|  | 
 | ||||||
|  |         discounts = DiscountController.available_discounts( | ||||||
|  |             self.USER_1, | ||||||
|  |             [], | ||||||
|  |             [self.PROD_1], | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         self.assertEqual(0, len(discounts)) | ||||||
|  | 
 | ||||||
|  |     def test_flag_ceiling_aggregates_products(self): | ||||||
|  |         # Create two carts, add 1xprod_1 to each. Ceiling should disappear | ||||||
|  |         # after second. | ||||||
|  |         self.make_ceiling("Multi-product limit ceiling", limit=2) | ||||||
|  | 
 | ||||||
|  |         for i in xrange(2): | ||||||
|  |             cart = TestingCartController.for_user(self.USER_1) | ||||||
|  |             cart.add_to_cart(self.PROD_1, 1) | ||||||
|  |             cart.next_cart() | ||||||
|  | 
 | ||||||
|  |         products = ProductController.available_products( | ||||||
|  |             self.USER_1, | ||||||
|  |             products=[self.PROD_1], | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         self.assertEqual(0, len(products)) | ||||||
|  | 
 | ||||||
|     def test_items_released_from_ceiling_by_refund(self): |     def test_items_released_from_ceiling_by_refund(self): | ||||||
|         self.make_ceiling("Limit ceiling", limit=1) |         self.make_ceiling("Limit ceiling", limit=1) | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -4,7 +4,7 @@ from decimal import Decimal | ||||||
| 
 | 
 | ||||||
| from registrasion.models import commerce | from registrasion.models import commerce | ||||||
| from registrasion.models import conditions | from registrasion.models import conditions | ||||||
| from registrasion.controllers import discount | from registrasion.controllers.discount import DiscountController | ||||||
| from controller_helpers import TestingCartController | from controller_helpers import TestingCartController | ||||||
| 
 | 
 | ||||||
| from test_cart import RegistrationCartTestCase | from test_cart import RegistrationCartTestCase | ||||||
|  | @ -243,22 +243,30 @@ class DiscountTestCase(RegistrationCartTestCase): | ||||||
|             # The discount is applied. |             # The discount is applied. | ||||||
|             self.assertEqual(1, len(discount_items)) |             self.assertEqual(1, len(discount_items)) | ||||||
| 
 | 
 | ||||||
|     # Tests for the discount.available_discounts enumerator |     # Tests for the DiscountController.available_discounts enumerator | ||||||
|     def test_enumerate_no_discounts_for_no_input(self): |     def test_enumerate_no_discounts_for_no_input(self): | ||||||
|         discounts = discount.available_discounts(self.USER_1, [], []) |         discounts = DiscountController.available_discounts( | ||||||
|  |             self.USER_1, | ||||||
|  |             [], | ||||||
|  |             [], | ||||||
|  |         ) | ||||||
|         self.assertEqual(0, len(discounts)) |         self.assertEqual(0, len(discounts)) | ||||||
| 
 | 
 | ||||||
|     def test_enumerate_no_discounts_if_condition_not_met(self): |     def test_enumerate_no_discounts_if_condition_not_met(self): | ||||||
|         self.add_discount_prod_1_includes_cat_2(quantity=1) |         self.add_discount_prod_1_includes_cat_2(quantity=1) | ||||||
| 
 | 
 | ||||||
|         discounts = discount.available_discounts( |         discounts = DiscountController.available_discounts( | ||||||
|             self.USER_1, |             self.USER_1, | ||||||
|             [], |             [], | ||||||
|             [self.PROD_3], |             [self.PROD_3], | ||||||
|         ) |         ) | ||||||
|         self.assertEqual(0, len(discounts)) |         self.assertEqual(0, len(discounts)) | ||||||
| 
 | 
 | ||||||
|         discounts = discount.available_discounts(self.USER_1, [self.CAT_2], []) |         discounts = DiscountController.available_discounts( | ||||||
|  |             self.USER_1, | ||||||
|  |             [self.CAT_2], | ||||||
|  |             [], | ||||||
|  |         ) | ||||||
|         self.assertEqual(0, len(discounts)) |         self.assertEqual(0, len(discounts)) | ||||||
| 
 | 
 | ||||||
|     def test_category_discount_appears_once_if_met_twice(self): |     def test_category_discount_appears_once_if_met_twice(self): | ||||||
|  | @ -267,7 +275,7 @@ class DiscountTestCase(RegistrationCartTestCase): | ||||||
|         cart = TestingCartController.for_user(self.USER_1) |         cart = TestingCartController.for_user(self.USER_1) | ||||||
|         cart.add_to_cart(self.PROD_1, 1)  # Enable the discount |         cart.add_to_cart(self.PROD_1, 1)  # Enable the discount | ||||||
| 
 | 
 | ||||||
|         discounts = discount.available_discounts( |         discounts = DiscountController.available_discounts( | ||||||
|             self.USER_1, |             self.USER_1, | ||||||
|             [self.CAT_2], |             [self.CAT_2], | ||||||
|             [self.PROD_3], |             [self.PROD_3], | ||||||
|  | @ -280,7 +288,11 @@ class DiscountTestCase(RegistrationCartTestCase): | ||||||
|         cart = TestingCartController.for_user(self.USER_1) |         cart = TestingCartController.for_user(self.USER_1) | ||||||
|         cart.add_to_cart(self.PROD_1, 1)  # Enable the discount |         cart.add_to_cart(self.PROD_1, 1)  # Enable the discount | ||||||
| 
 | 
 | ||||||
|         discounts = discount.available_discounts(self.USER_1, [self.CAT_2], []) |         discounts = DiscountController.available_discounts( | ||||||
|  |             self.USER_1, | ||||||
|  |             [self.CAT_2], | ||||||
|  |             [], | ||||||
|  |         ) | ||||||
|         self.assertEqual(1, len(discounts)) |         self.assertEqual(1, len(discounts)) | ||||||
| 
 | 
 | ||||||
|     def test_category_discount_appears_with_product(self): |     def test_category_discount_appears_with_product(self): | ||||||
|  | @ -289,7 +301,7 @@ class DiscountTestCase(RegistrationCartTestCase): | ||||||
|         cart = TestingCartController.for_user(self.USER_1) |         cart = TestingCartController.for_user(self.USER_1) | ||||||
|         cart.add_to_cart(self.PROD_1, 1)  # Enable the discount |         cart.add_to_cart(self.PROD_1, 1)  # Enable the discount | ||||||
| 
 | 
 | ||||||
|         discounts = discount.available_discounts( |         discounts = DiscountController.available_discounts( | ||||||
|             self.USER_1, |             self.USER_1, | ||||||
|             [], |             [], | ||||||
|             [self.PROD_3], |             [self.PROD_3], | ||||||
|  | @ -302,7 +314,7 @@ class DiscountTestCase(RegistrationCartTestCase): | ||||||
|         cart = TestingCartController.for_user(self.USER_1) |         cart = TestingCartController.for_user(self.USER_1) | ||||||
|         cart.add_to_cart(self.PROD_1, 1)  # Enable the discount |         cart.add_to_cart(self.PROD_1, 1)  # Enable the discount | ||||||
| 
 | 
 | ||||||
|         discounts = discount.available_discounts( |         discounts = DiscountController.available_discounts( | ||||||
|             self.USER_1, |             self.USER_1, | ||||||
|             [], |             [], | ||||||
|             [self.PROD_3, self.PROD_4] |             [self.PROD_3, self.PROD_4] | ||||||
|  | @ -315,7 +327,7 @@ class DiscountTestCase(RegistrationCartTestCase): | ||||||
|         cart = TestingCartController.for_user(self.USER_1) |         cart = TestingCartController.for_user(self.USER_1) | ||||||
|         cart.add_to_cart(self.PROD_1, 1)  # Enable the discount |         cart.add_to_cart(self.PROD_1, 1)  # Enable the discount | ||||||
| 
 | 
 | ||||||
|         discounts = discount.available_discounts( |         discounts = DiscountController.available_discounts( | ||||||
|             self.USER_1, |             self.USER_1, | ||||||
|             [], |             [], | ||||||
|             [self.PROD_2], |             [self.PROD_2], | ||||||
|  | @ -328,7 +340,11 @@ class DiscountTestCase(RegistrationCartTestCase): | ||||||
|         cart = TestingCartController.for_user(self.USER_1) |         cart = TestingCartController.for_user(self.USER_1) | ||||||
|         cart.add_to_cart(self.PROD_1, 1)  # Enable the discount |         cart.add_to_cart(self.PROD_1, 1)  # Enable the discount | ||||||
| 
 | 
 | ||||||
|         discounts = discount.available_discounts(self.USER_1, [self.CAT_1], []) |         discounts = DiscountController.available_discounts( | ||||||
|  |             self.USER_1, | ||||||
|  |             [self.CAT_1], | ||||||
|  |             [], | ||||||
|  |         ) | ||||||
|         self.assertEqual(0, len(discounts)) |         self.assertEqual(0, len(discounts)) | ||||||
| 
 | 
 | ||||||
|     def test_discount_quantity_is_correct_before_first_purchase(self): |     def test_discount_quantity_is_correct_before_first_purchase(self): | ||||||
|  | @ -338,7 +354,11 @@ class DiscountTestCase(RegistrationCartTestCase): | ||||||
|         cart.add_to_cart(self.PROD_1, 1)  # Enable the discount |         cart.add_to_cart(self.PROD_1, 1)  # Enable the discount | ||||||
|         cart.add_to_cart(self.PROD_3, 1)  # Exhaust the quantity |         cart.add_to_cart(self.PROD_3, 1)  # Exhaust the quantity | ||||||
| 
 | 
 | ||||||
|         discounts = discount.available_discounts(self.USER_1, [self.CAT_2], []) |         discounts = DiscountController.available_discounts( | ||||||
|  |             self.USER_1, | ||||||
|  |             [self.CAT_2], | ||||||
|  |             [], | ||||||
|  |         ) | ||||||
|         self.assertEqual(2, discounts[0].quantity) |         self.assertEqual(2, discounts[0].quantity) | ||||||
| 
 | 
 | ||||||
|         cart.next_cart() |         cart.next_cart() | ||||||
|  | @ -349,32 +369,63 @@ class DiscountTestCase(RegistrationCartTestCase): | ||||||
|         cart = TestingCartController.for_user(self.USER_1) |         cart = TestingCartController.for_user(self.USER_1) | ||||||
|         cart.add_to_cart(self.PROD_3, 1)  # Exhaust the quantity |         cart.add_to_cart(self.PROD_3, 1)  # Exhaust the quantity | ||||||
| 
 | 
 | ||||||
|         discounts = discount.available_discounts(self.USER_1, [self.CAT_2], []) |         discounts = DiscountController.available_discounts( | ||||||
|  |             self.USER_1, | ||||||
|  |             [self.CAT_2], | ||||||
|  |             [], | ||||||
|  |         ) | ||||||
|         self.assertEqual(1, discounts[0].quantity) |         self.assertEqual(1, discounts[0].quantity) | ||||||
| 
 | 
 | ||||||
|         cart.next_cart() |         cart.next_cart() | ||||||
| 
 | 
 | ||||||
|     def test_discount_is_gone_after_quantity_exhausted(self): |     def test_discount_is_gone_after_quantity_exhausted(self): | ||||||
|         self.test_discount_quantity_is_correct_after_first_purchase() |         self.test_discount_quantity_is_correct_after_first_purchase() | ||||||
|         discounts = discount.available_discounts(self.USER_1, [self.CAT_2], []) |         discounts = DiscountController.available_discounts( | ||||||
|  |             self.USER_1, | ||||||
|  |             [self.CAT_2], | ||||||
|  |             [], | ||||||
|  |         ) | ||||||
|         self.assertEqual(0, len(discounts)) |         self.assertEqual(0, len(discounts)) | ||||||
| 
 | 
 | ||||||
|     def test_product_discount_enabled_twice_appears_twice(self): |     def test_product_discount_enabled_twice_appears_twice(self): | ||||||
|         self.add_discount_prod_1_includes_prod_3_and_prod_4(quantity=2) |         self.add_discount_prod_1_includes_prod_3_and_prod_4(quantity=2) | ||||||
|         cart = TestingCartController.for_user(self.USER_1) |         cart = TestingCartController.for_user(self.USER_1) | ||||||
|         cart.add_to_cart(self.PROD_1, 1)  # Enable the discount |         cart.add_to_cart(self.PROD_1, 1)  # Enable the discount | ||||||
|         discounts = discount.available_discounts( |         discounts = DiscountController.available_discounts( | ||||||
|             self.USER_1, |             self.USER_1, | ||||||
|             [], |             [], | ||||||
|             [self.PROD_3, self.PROD_4], |             [self.PROD_3, self.PROD_4], | ||||||
|         ) |         ) | ||||||
|         self.assertEqual(2, len(discounts)) |         self.assertEqual(2, len(discounts)) | ||||||
| 
 | 
 | ||||||
|  |     def test_product_discount_applied_on_different_invoices(self): | ||||||
|  |         # quantity=1 means "quantity per product" | ||||||
|  |         self.add_discount_prod_1_includes_prod_3_and_prod_4(quantity=1) | ||||||
|  |         cart = TestingCartController.for_user(self.USER_1) | ||||||
|  |         cart.add_to_cart(self.PROD_1, 1)  # Enable the discount | ||||||
|  |         discounts = DiscountController.available_discounts( | ||||||
|  |             self.USER_1, | ||||||
|  |             [], | ||||||
|  |             [self.PROD_3, self.PROD_4], | ||||||
|  |         ) | ||||||
|  |         self.assertEqual(2, len(discounts)) | ||||||
|  |         # adding one of PROD_3 should make it no longer an available discount. | ||||||
|  |         cart.add_to_cart(self.PROD_3, 1) | ||||||
|  |         cart.next_cart() | ||||||
|  | 
 | ||||||
|  |         # should still have (and only have) the discount for prod_4 | ||||||
|  |         discounts = DiscountController.available_discounts( | ||||||
|  |             self.USER_1, | ||||||
|  |             [], | ||||||
|  |             [self.PROD_3, self.PROD_4], | ||||||
|  |         ) | ||||||
|  |         self.assertEqual(1, len(discounts)) | ||||||
|  | 
 | ||||||
|     def test_discounts_are_released_by_refunds(self): |     def test_discounts_are_released_by_refunds(self): | ||||||
|         self.add_discount_prod_1_includes_prod_2(quantity=2) |         self.add_discount_prod_1_includes_prod_2(quantity=2) | ||||||
|         cart = TestingCartController.for_user(self.USER_1) |         cart = TestingCartController.for_user(self.USER_1) | ||||||
|         cart.add_to_cart(self.PROD_1, 1)  # Enable the discount |         cart.add_to_cart(self.PROD_1, 1)  # Enable the discount | ||||||
|         discounts = discount.available_discounts( |         discounts = DiscountController.available_discounts( | ||||||
|             self.USER_1, |             self.USER_1, | ||||||
|             [], |             [], | ||||||
|             [self.PROD_2], |             [self.PROD_2], | ||||||
|  | @ -388,7 +439,7 @@ class DiscountTestCase(RegistrationCartTestCase): | ||||||
| 
 | 
 | ||||||
|         cart.next_cart() |         cart.next_cart() | ||||||
| 
 | 
 | ||||||
|         discounts = discount.available_discounts( |         discounts = DiscountController.available_discounts( | ||||||
|             self.USER_1, |             self.USER_1, | ||||||
|             [], |             [], | ||||||
|             [self.PROD_2], |             [self.PROD_2], | ||||||
|  | @ -398,7 +449,7 @@ class DiscountTestCase(RegistrationCartTestCase): | ||||||
|         cart.cart.status = commerce.Cart.STATUS_RELEASED |         cart.cart.status = commerce.Cart.STATUS_RELEASED | ||||||
|         cart.cart.save() |         cart.cart.save() | ||||||
| 
 | 
 | ||||||
|         discounts = discount.available_discounts( |         discounts = DiscountController.available_discounts( | ||||||
|             self.USER_1, |             self.USER_1, | ||||||
|             [], |             [], | ||||||
|             [self.PROD_2], |             [self.PROD_2], | ||||||
|  |  | ||||||
|  | @ -25,3 +25,33 @@ def all_arguments_optional(ntcls): | ||||||
|     ) |     ) | ||||||
| 
 | 
 | ||||||
|     return ntcls |     return ntcls | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def lazy(function, *args, **kwargs): | ||||||
|  |     ''' Produces a callable so that functions can be lazily evaluated in | ||||||
|  |     templates. | ||||||
|  | 
 | ||||||
|  |     Arguments: | ||||||
|  | 
 | ||||||
|  |         function (callable): The function to call at evaluation time. | ||||||
|  | 
 | ||||||
|  |         args: Positional arguments, passed directly to ``function``. | ||||||
|  | 
 | ||||||
|  |         kwargs: Keyword arguments, passed directly to ``function``. | ||||||
|  | 
 | ||||||
|  |     Return: | ||||||
|  | 
 | ||||||
|  |         callable: A callable that will evaluate a call to ``function`` with | ||||||
|  |             the specified arguments. | ||||||
|  | 
 | ||||||
|  |     ''' | ||||||
|  | 
 | ||||||
|  |     NOT_EVALUATED = object() | ||||||
|  |     retval = [NOT_EVALUATED] | ||||||
|  | 
 | ||||||
|  |     def evaluate(): | ||||||
|  |         if retval[0] is NOT_EVALUATED: | ||||||
|  |             retval[0] = function(*args, **kwargs) | ||||||
|  |         return retval[0] | ||||||
|  | 
 | ||||||
|  |     return evaluate | ||||||
|  |  | ||||||
|  | @ -5,7 +5,7 @@ from registrasion import util | ||||||
| from registrasion.models import commerce | from registrasion.models import commerce | ||||||
| from registrasion.models import inventory | from registrasion.models import inventory | ||||||
| from registrasion.models import people | from registrasion.models import people | ||||||
| from registrasion.controllers import discount | from registrasion.controllers.discount import DiscountController | ||||||
| from registrasion.controllers.cart import CartController | from registrasion.controllers.cart import CartController | ||||||
| from registrasion.controllers.credit_note import CreditNoteController | from registrasion.controllers.credit_note import CreditNoteController | ||||||
| from registrasion.controllers.invoice import InvoiceController | from registrasion.controllers.invoice import InvoiceController | ||||||
|  | @ -181,33 +181,35 @@ def guided_registration(request): | ||||||
|             attendee.save() |             attendee.save() | ||||||
|             return next_step |             return next_step | ||||||
| 
 | 
 | ||||||
|         for category in cats: |         with CartController.operations_batch(request.user): | ||||||
|             products = [ |             for category in cats: | ||||||
|                 i for i in available_products |                 products = [ | ||||||
|                 if i.category == category |                     i for i in available_products | ||||||
|             ] |                     if i.category == category | ||||||
|  |                 ] | ||||||
| 
 | 
 | ||||||
|             prefix = "category_" + str(category.id) |                 prefix = "category_" + str(category.id) | ||||||
|             p = _handle_products(request, category, products, prefix) |                 p = _handle_products(request, category, products, prefix) | ||||||
|             products_form, discounts, products_handled = p |                 products_form, discounts, products_handled = p | ||||||
| 
 | 
 | ||||||
|             section = GuidedRegistrationSection( |                 section = GuidedRegistrationSection( | ||||||
|                 title=category.name, |                     title=category.name, | ||||||
|                 description=category.description, |                     description=category.description, | ||||||
|                 discounts=discounts, |                     discounts=discounts, | ||||||
|                 form=products_form, |                     form=products_form, | ||||||
|             ) |                 ) | ||||||
| 
 | 
 | ||||||
|             if products: |                 if products: | ||||||
|                 # This product category has items to show. |                     # This product category has items to show. | ||||||
|                 sections.append(section) |                     sections.append(section) | ||||||
|                 # Add this to the list of things to show if the form errors. |                     # Add this to the list of things to show if the form | ||||||
|                 request.session[SESSION_KEY].append(category.id) |                     # errors. | ||||||
|  |                     request.session[SESSION_KEY].append(category.id) | ||||||
| 
 | 
 | ||||||
|                 if request.method == "POST" and not products_form.errors: |                     if request.method == "POST" and not products_form.errors: | ||||||
|                     # This is only saved if we pass each form with no errors, |                         # This is only saved if we pass each form with no | ||||||
|                     # and if the form actually has products. |                         # errors, and if the form actually has products. | ||||||
|                     attendee.guided_categories_complete.add(category) |                         attendee.guided_categories_complete.add(category) | ||||||
| 
 | 
 | ||||||
|     if sections and request.method == "POST": |     if sections and request.method == "POST": | ||||||
|         for section in sections: |         for section in sections: | ||||||
|  | @ -427,7 +429,15 @@ def _handle_products(request, category, products, prefix): | ||||||
|                 ) |                 ) | ||||||
|     handled = False if products_form.errors else True |     handled = False if products_form.errors else True | ||||||
| 
 | 
 | ||||||
|     discounts = discount.available_discounts(request.user, [], products) |     # Making this a function to lazily evaluate when it's displayed | ||||||
|  |     # in templates. | ||||||
|  | 
 | ||||||
|  |     discounts = util.lazy( | ||||||
|  |         DiscountController.available_discounts, | ||||||
|  |         request.user, | ||||||
|  |         [], | ||||||
|  |         products, | ||||||
|  |     ) | ||||||
| 
 | 
 | ||||||
|     return products_form, discounts, handled |     return products_form, discounts, handled | ||||||
| 
 | 
 | ||||||
|  | @ -435,14 +445,14 @@ def _handle_products(request, category, products, prefix): | ||||||
| def _set_quantities_from_products_form(products_form, current_cart): | def _set_quantities_from_products_form(products_form, current_cart): | ||||||
| 
 | 
 | ||||||
|     quantities = list(products_form.product_quantities()) |     quantities = list(products_form.product_quantities()) | ||||||
| 
 |     id_to_quantity = dict(i[:2] for i in quantities) | ||||||
|     pks = [i[0] for i in quantities] |     pks = [i[0] for i in quantities] | ||||||
|     products = inventory.Product.objects.filter( |     products = inventory.Product.objects.filter( | ||||||
|         id__in=pks, |         id__in=pks, | ||||||
|     ).select_related("category") |     ).select_related("category") | ||||||
| 
 | 
 | ||||||
|     product_quantities = [ |     product_quantities = [ | ||||||
|         (products.get(pk=i[0]), i[1]) for i in quantities |         (product, id_to_quantity[product.id]) for product in products | ||||||
|     ] |     ] | ||||||
|     field_names = dict( |     field_names = dict( | ||||||
|         (i[0][0], i[1][2]) for i in zip(product_quantities, quantities) |         (i[0][0], i[1][2]) for i in zip(product_quantities, quantities) | ||||||
|  |  | ||||||
|  | @ -1,2 +1,2 @@ | ||||||
| [flake8] | [flake8] | ||||||
| exclude = registrasion/migrations/*, build/*, docs/* | exclude = registrasion/migrations/*, build/*, docs/*, dist/* | ||||||
|  |  | ||||||
		Loading…
	
	Add table
		
		Reference in a new issue
	
	 Christopher Neugebauer
						Christopher Neugebauer