diff --git a/registrasion/controllers/discount.py b/registrasion/controllers/discount.py index 1c7fa59f..164d95cc 100644 --- a/registrasion/controllers/discount.py +++ b/registrasion/controllers/discount.py @@ -5,7 +5,7 @@ from registrasion.models import commerce from registrasion.models import conditions from django.db.models import Case -from django.db.models import Q +from django.db.models import F, Q from django.db.models import Sum from django.db.models import Value from django.db.models import When @@ -64,9 +64,7 @@ class DiscountController(object): discount = clause.discount cond = ConditionController.for_condition(discount) - past_use_count = discount.past_use_count - - + past_use_count = clause.past_use_count if past_use_count >= clause.quantity: # This clause has exceeded its use count pass @@ -139,7 +137,6 @@ class DiscountController(object): discounts = discounttype.objects.filter(id__in=valid_discounts) ctrl = ConditionController.for_type(discounttype) discounts = ctrl.pre_filter(discounts, user) - discounts = cls._annotate_with_past_uses(discounts, user) all_subsets.append(discounts) filtered_discounts = list(itertools.chain(*all_subsets)) @@ -148,11 +145,17 @@ class DiscountController(object): # (contains annotations needed in the future) from_filter = dict((i.id, i) for i in filtered_discounts) - # The set of all potential discounts - discount_clauses = set(itertools.chain( + clause_sets = ( product_discounts.filter(discount__in=filtered_discounts), all_category_discounts.filter(discount__in=filtered_discounts), - )) + ) + + clause_sets = ( + cls._annotate_with_past_uses(i, user) for i in clause_sets + ) + + # The set of all potential discount clauses + discount_clauses = set(itertools.chain(*clause_sets)) # Replace discounts with the filtered ones # These are the correct subclasses (saves query later on), and have @@ -164,15 +167,26 @@ class DiscountController(object): @classmethod def _annotate_with_past_uses(cls, queryset, user): - ''' Annotates the queryset with a usage count for that discount by the - given user. ''' + ''' Annotates the queryset with a usage count for that discount claus + by the given user. ''' + + if queryset.model == conditions.DiscountForCategory: + matches = ( + Q(category=F('discount__discountitem__product__category')) + ) + elif queryset.model == conditions.DiscountForProduct: + matches = ( + Q(product=F('discount__discountitem__product')) + ) + + in_carts = ( + Q(discount__discountitem__cart__user=user) & + Q(discount__discountitem__cart__status=commerce.Cart.STATUS_PAID) + ) past_use_quantity = When( - ( - Q(discountitem__cart__user=user) & - Q(discountitem__cart__status=commerce.Cart.STATUS_PAID) - ), - then="discountitem__quantity", + in_carts & matches, + then="discount__discountitem__quantity", ) past_use_quantity_or_zero = Case( diff --git a/registrasion/tests/test_discount.py b/registrasion/tests/test_discount.py index 4b92c81b..d7920a10 100644 --- a/registrasion/tests/test_discount.py +++ b/registrasion/tests/test_discount.py @@ -398,6 +398,29 @@ class DiscountTestCase(RegistrationCartTestCase): ) self.assertEqual(2, len(discounts)) + def test_product_discount_applied_on_different_invoices(self): + # quantity=1 means "quantity per product" + self.add_discount_prod_1_includes_prod_3_and_prod_4(quantity=1) + cart = TestingCartController.for_user(self.USER_1) + cart.add_to_cart(self.PROD_1, 1) # Enable the discount + discounts = DiscountController.available_discounts( + self.USER_1, + [], + [self.PROD_3, self.PROD_4], + ) + self.assertEqual(2, len(discounts)) + # adding one of PROD_3 should make it no longer an available discount. + cart.add_to_cart(self.PROD_3, 1) + cart.next_cart() + + # should still have (and only have) the discount for prod_4 + discounts = DiscountController.available_discounts( + self.USER_1, + [], + [self.PROD_3, self.PROD_4], + ) + self.assertEqual(1, len(discounts)) + def test_discounts_are_released_by_refunds(self): self.add_discount_prod_1_includes_prod_2(quantity=2) cart = TestingCartController.for_user(self.USER_1)