From 0182a32f03475d673d745ec7fad9d10c41a847fc Mon Sep 17 00:00:00 2001 From: Christopher Neugebauer Date: Fri, 4 Mar 2016 13:07:45 -0800 Subject: [PATCH] Fixes various errors in discount calculation, and adds tests for these --- registrasion/controllers/cart.py | 9 +++++--- registrasion/tests/test_cart.py | 13 ++++++++++- registrasion/tests/test_discount.py | 35 +++++++++++++++++++++++++++-- 3 files changed, 51 insertions(+), 6 deletions(-) diff --git a/registrasion/controllers/cart.py b/registrasion/controllers/cart.py index acdce035..5cbe5c8c 100644 --- a/registrasion/controllers/cart.py +++ b/registrasion/controllers/cart.py @@ -183,7 +183,12 @@ class CartController(object): # Delete the existing entries. rego.DiscountItem.objects.filter(cart=self.cart).delete() - for item in self.cart.productitem_set.all(): + # The highest-value discounts will apply to the highest-value + # products first. + product_items = self.cart.productitem_set.all() + product_items = product_items.order_by('product__price') + product_items = reversed(product_items) + for item in product_items: self._add_discount(item.product, item.quantity) def _add_discount(self, product, quantity): @@ -202,9 +207,7 @@ class CartController(object): # Get the count of past uses of this discount condition # as this affects the total amount we're allowed to use now. past_uses = rego.DiscountItem.objects.filter( - cart__active=False, discount=discount.discount, - product=product, ) agg = past_uses.aggregate(Sum("quantity")) past_uses = agg["quantity__sum"] diff --git a/registrasion/tests/test_cart.py b/registrasion/tests/test_cart.py index dad32efa..c65011b9 100644 --- a/registrasion/tests/test_cart.py +++ b/registrasion/tests/test_cart.py @@ -82,7 +82,18 @@ class RegistrationCartTestCase(SetTimeMixin, TestCase): limit_per_user=10, order=10, ) - cls.PROD_2.save() + cls.PROD_3.save() + + cls.PROD_4 = rego.Product.objects.create( + name="Product 4", + description="This is a test product. It costs $5. " + "A user may have 10 of them.", + category=cls.CAT_2, + price=Decimal("5.00"), + limit_per_user=10, + order=10, + ) + cls.PROD_4.save() @classmethod def make_ceiling(cls, name, limit=None, start_time=None, end_time=None): diff --git a/registrasion/tests/test_discount.py b/registrasion/tests/test_discount.py index c0709e8e..5d325eff 100644 --- a/registrasion/tests/test_discount.py +++ b/registrasion/tests/test_discount.py @@ -29,7 +29,10 @@ class DiscountTestCase(RegistrationCartTestCase): return discount @classmethod - def add_discount_prod_1_includes_cat_2(cls, amount=Decimal(100)): + def add_discount_prod_1_includes_cat_2( + cls, + amount=Decimal(100), + quantity=2): discount = rego.IncludedProductDiscount.objects.create( description="PROD_1 includes CAT_2 " + str(amount) + "%", ) @@ -40,7 +43,7 @@ class DiscountTestCase(RegistrationCartTestCase): discount=discount, category=cls.CAT_2, percentage=amount, - quantity=2 + quantity=quantity, ).save() return discount @@ -169,3 +172,31 @@ class DiscountTestCase(RegistrationCartTestCase): discount_items = list(cart.cart.discountitem_set.all()) self.assertEqual(2, discount_items[0].quantity) + + def test_category_discount_applies_once_per_category(self): + self.add_discount_prod_1_includes_cat_2(quantity=1) + cart = CartController.for_user(self.USER_1) + cart.add_to_cart(self.PROD_1, 1) + + # Add two items from category 2 + cart.add_to_cart(self.PROD_3, 1) + cart.add_to_cart(self.PROD_4, 1) + + discount_items = list(cart.cart.discountitem_set.all()) + # There is one discount, and it should apply to one item. + self.assertEqual(1, len(discount_items)) + self.assertEqual(1, discount_items[0].quantity) + + def test_category_discount_applies_to_highest_value(self): + self.add_discount_prod_1_includes_cat_2(quantity=1) + cart = CartController.for_user(self.USER_1) + cart.add_to_cart(self.PROD_1, 1) + + # Add two items from category 2, add the less expensive one first + cart.add_to_cart(self.PROD_4, 1) + cart.add_to_cart(self.PROD_3, 1) + + discount_items = list(cart.cart.discountitem_set.all()) + # There is one discount, and it should apply to the more expensive. + self.assertEqual(1, len(discount_items)) + self.assertEqual(self.PROD_3, discount_items[0].product)