diff --git a/registrasion/controllers/cart.py b/registrasion/controllers/cart.py index ad4458ea..204dc9f6 100644 --- a/registrasion/controllers/cart.py +++ b/registrasion/controllers/cart.py @@ -399,9 +399,11 @@ class CartController(object): # Delete the existing entries. commerce.DiscountItem.objects.filter(cart=self.cart).delete() + # Order the products such that the most expensive ones are + # processed first. product_items = self.cart.productitem_set.all().select_related( "product", "product__category", "product__price" - ) + ).order_by("-product__price") products = [i.product for i in product_items] discounts = DiscountController.available_discounts( @@ -411,8 +413,7 @@ class CartController(object): ) # The highest-value discounts will apply to the highest-value - # products first. - product_items = reversed(product_items) + # products first, because of the order_by clause for item in product_items: self._add_discount(item.product, item.quantity, discounts) diff --git a/registrasion/tests/test_discount.py b/registrasion/tests/test_discount.py index d7920a10..2696535b 100644 --- a/registrasion/tests/test_discount.py +++ b/registrasion/tests/test_discount.py @@ -243,6 +243,29 @@ class DiscountTestCase(RegistrationCartTestCase): # The discount is applied. self.assertEqual(1, len(discount_items)) + def test_discount_applies_to_most_expensive_item(self): + self.add_discount_prod_1_includes_cat_2(quantity=1) + + cart = TestingCartController.for_user(self.USER_1) + cart.add_to_cart(self.PROD_1, 1) # Enable the discount + + import itertools + prods = (self.PROD_3, self.PROD_4) + for first, second in itertools.permutations(prods, 2): + + cart.set_quantity(first, 1) + cart.set_quantity(second, 1) + + # There should only be one discount + discount_items = list(cart.cart.discountitem_set.all()) + self.assertEqual(1, len(discount_items)) + + # It should always apply to PROD_3, as it costs more. + self.assertEqual(discount_items[0].product, self.PROD_3) + + cart.set_quantity(first, 0) + cart.set_quantity(second, 0) + # Tests for the DiscountController.available_discounts enumerator def test_enumerate_no_discounts_for_no_input(self): discounts = DiscountController.available_discounts(