diff --git a/registrasion/controllers/cart.py b/registrasion/controllers/cart.py index d9307b39..ad680225 100644 --- a/registrasion/controllers/cart.py +++ b/registrasion/controllers/cart.py @@ -1,8 +1,9 @@ import datetime +import discount from django.core.exceptions import ObjectDoesNotExist from django.core.exceptions import ValidationError -from django.db.models import Max, Sum +from django.db.models import Max from django.utils import timezone from registrasion import models as rego @@ -187,38 +188,47 @@ class CartController(object): # Delete the existing entries. rego.DiscountItem.objects.filter(cart=self.cart).delete() + product_items = self.cart.productitem_set.all() + + products = [i.product for i in product_items] + discounts = discount.available_discounts(self.cart.user, [], products) + # The highest-value discounts will apply to the highest-value # products first. product_items = self.cart.productitem_set.all() product_items = product_items.order_by('product__price') product_items = reversed(product_items) for item in product_items: - self._add_discount(item.product, item.quantity) + self._add_discount(item.product, item.quantity, discounts) - def _add_discount(self, product, quantity): - ''' Calculates the best available discounts for this product. - NB this will be super-inefficient in aggregate because discounts will - be re-tested for each product. We should work on that.''' + def _add_discount(self, product, quantity, discounts): + ''' Applies the best discounts on the given product, from the given + discounts.''' - prod = ProductController(product) - discounts = prod.available_discounts(self.cart.user) - discounts.sort(key=lambda discount: discount.value) + def matches(discount): + ''' Returns True if and only if the given discount apples to + our product. ''' + if isinstance(discount.clause, rego.DiscountForCategory): + return discount.clause.category == product.category + else: + return discount.clause.product == product - for discount in reversed(discounts): + def value(discount): + ''' Returns the value of this discount clause + as applied to this product ''' + if discount.clause.percentage is not None: + return discount.clause.percentage * product.price + else: + return discount.clause.price + + discounts = [i for i in discounts if matches(i)] + discounts.sort(key=value) + + for candidate in reversed(discounts): if quantity == 0: break - - # Get the count of past uses of this discount condition - # as this affects the total amount we're allowed to use now. - past_uses = rego.DiscountItem.objects.filter( - cart__user=self.cart.user, - discount=discount.discount, - ) - agg = past_uses.aggregate(Sum("quantity")) - past_uses = agg["quantity__sum"] - if past_uses is None: - past_uses = 0 - if past_uses == discount.condition.quantity: + elif candidate.quantity == 0: + # This discount clause has been exhausted by this cart continue # Get a provisional instance for this DiscountItem @@ -226,13 +236,13 @@ class CartController(object): discount_item = rego.DiscountItem.objects.create( product=product, cart=self.cart, - discount=discount.discount, + discount=candidate.discount, quantity=quantity, ) # Truncate the quantity for this DiscountItem if we exceed quantity ours = discount_item.quantity - allowed = discount.condition.quantity - past_uses + allowed = candidate.quantity if ours > allowed: discount_item.quantity = allowed # Update the remaining quantity. @@ -240,4 +250,6 @@ class CartController(object): else: quantity = 0 + candidate.quantity -= discount_item.quantity + discount_item.save() diff --git a/registrasion/controllers/discount.py b/registrasion/controllers/discount.py new file mode 100644 index 00000000..dcb83dfd --- /dev/null +++ b/registrasion/controllers/discount.py @@ -0,0 +1,83 @@ +import itertools + +from conditions import ConditionController +from registrasion import models as rego + +from django.db.models import Sum + + +class DiscountAndQuantity(object): + def __init__(self, discount, clause, quantity): + self.discount = discount + self.clause = clause + self.quantity = quantity + + +def available_discounts(user, categories, products): + ''' Returns all discounts available to this user for the given categories + and products. The discounts also list the available quantity for this user, + not including products that are pending purchase. ''' + + # discounts that match provided categories + category_discounts = rego.DiscountForCategory.objects.filter( + category__in=categories + ) + # discounts that match provided products + product_discounts = rego.DiscountForProduct.objects.filter( + product__in=products + ) + # discounts that match categories for provided products + product_category_discounts = rego.DiscountForCategory.objects.filter( + category__in=(product.category for product in products) + ) + # (Not relevant: discounts that match products in provided categories) + + # The set of all potential discounts + potential_discounts = set(itertools.chain( + product_discounts, + category_discounts, + product_category_discounts, + )) + + discounts = [] + + # Markers so that we don't need to evaluate given conditions more than once + accepted_discounts = set() + failed_discounts = set() + + for discount in potential_discounts: + real_discount = rego.DiscountBase.objects.get_subclass( + pk=discount.discount.pk, + ) + cond = ConditionController.for_condition(real_discount) + + # Count the past uses of the given discount item. + # If this user has exceeded the limit for the clause, this clause + # is not available any more. + past_uses = rego.DiscountItem.objects.filter( + cart__user=user, + cart__active=False, # Only past carts count + discount=discount.discount, + ) + agg = past_uses.aggregate(Sum("quantity")) + past_use_count = agg["quantity__sum"] + if past_use_count is None: + past_use_count = 0 + + if past_use_count >= discount.quantity: + # This clause has exceeded its use count + pass + elif real_discount not in failed_discounts: + # This clause is still available + if real_discount in accepted_discounts or cond.is_met(user, 0): + # This clause is valid for this user + discounts.append(DiscountAndQuantity( + discount=real_discount, + clause=discount, + quantity=discount.quantity - past_use_count, + )) + accepted_discounts.add(real_discount) + else: + # This clause is not valid for this user + failed_discounts.add(real_discount) + return discounts diff --git a/registrasion/controllers/product.py b/registrasion/controllers/product.py index 8a1f402e..2d0f2963 100644 --- a/registrasion/controllers/product.py +++ b/registrasion/controllers/product.py @@ -1,24 +1,39 @@ import itertools -from collections import namedtuple - from django.db.models import Q from registrasion import models as rego from conditions import ConditionController -DiscountEnabler = namedtuple( - "DiscountEnabler", ( - "discount", - "condition", - "value")) - class ProductController(object): def __init__(self, product): self.product = product + @classmethod + def available_products(cls, user, category=None, products=None): + ''' Returns a list of all of the products that are available per + enabling conditions from the given categories. + TODO: refactor so that all conditions are tested here and + can_add_with_enabling_conditions calls this method. ''' + if category is None and products is None: + raise ValueError("You must provide products or a category") + + if category is not None: + all_products = rego.Product.objects.filter(category=category) + else: + all_products = [] + + if products is not None: + all_products = itertools.chain(all_products, products) + + return [ + product + for product in all_products + if cls(product).can_add_with_enabling_conditions(user, 0) + ] + def user_can_add_within_limit(self, user, quantity): ''' Return true if the user is able to add _quantity_ to their count of this Product without exceeding _limit_per_user_.''' @@ -68,39 +83,3 @@ class ProductController(object): return False return True - - def get_enabler(self, condition): - if condition.percentage is not None: - value = condition.percentage * self.product.price - else: - value = condition.price - return DiscountEnabler( - discount=condition.discount, - condition=condition, - value=value - ) - - def available_discounts(self, user): - ''' Returns the set of available discounts for this user, for this - product. ''' - - product_discounts = rego.DiscountForProduct.objects.filter( - product=self.product) - category_discounts = rego.DiscountForCategory.objects.filter( - category=self.product.category - ) - - potential_discounts = set(itertools.chain( - (self.get_enabler(i) for i in product_discounts), - (self.get_enabler(i) for i in category_discounts), - )) - - discounts = [] - for discount in potential_discounts: - real_discount = rego.DiscountBase.objects.get_subclass( - pk=discount.discount.pk) - cond = ConditionController.for_condition(real_discount) - if cond.is_met(user, 0): - discounts.append(discount) - - return discounts diff --git a/registrasion/forms.py b/registrasion/forms.py index 79e6d95a..fd0359bb 100644 --- a/registrasion/forms.py +++ b/registrasion/forms.py @@ -1,21 +1,26 @@ import models as rego -from controllers.product import ProductController - from django import forms -def CategoryForm(category): +def ProductsForm(products): PREFIX = "product_" def field_name(product): return PREFIX + ("%d" % product.id) - class _CategoryForm(forms.Form): + class _ProductsForm(forms.Form): - @staticmethod - def initial_data(product_quantities): + def __init__(self, *a, **k): + if "product_quantities" in k: + initial = _ProductsForm.initial_data(k["product_quantities"]) + k["initial"] = initial + del k["product_quantities"] + super(_ProductsForm, self).__init__(*a, **k) + + @classmethod + def initial_data(cls, product_quantities): ''' Prepares initial data for an instance of this form. product_quantities is a sequence of (product,quantity) tuples ''' initial = {} @@ -32,18 +37,6 @@ def CategoryForm(category): product_id = int(name[len(PREFIX):]) yield (product_id, value, name) - def disable_product(self, product): - ''' Removes a given product from this form. ''' - del self.fields[field_name(product)] - - def disable_products_for_user(self, user): - for product in products: - # Remove fields that do not have an enabling condition. - prod = ProductController(product) - if not prod.can_add_with_enabling_conditions(user, 0): - self.disable_product(product) - - products = rego.Product.objects.filter(category=category).order_by("order") for product in products: help_text = "$%d -- %s" % (product.price, product.description) @@ -52,9 +45,9 @@ def CategoryForm(category): label=product.name, help_text=help_text, ) - _CategoryForm.base_fields[field_name(product)] = field + _ProductsForm.base_fields[field_name(product)] = field - return _CategoryForm + return _ProductsForm class ProfileForm(forms.ModelForm): diff --git a/registrasion/templates/product_category.html b/registrasion/templates/product_category.html index 0e567bf0..fb54a58d 100644 --- a/registrasion/templates/product_category.html +++ b/registrasion/templates/product_category.html @@ -5,8 +5,6 @@
{{ category.description }}
- diff --git a/registrasion/templatetags/registrasion_tags.py b/registrasion/templatetags/registrasion_tags.py index 4357d16a..a6741c0e 100644 --- a/registrasion/templatetags/registrasion_tags.py +++ b/registrasion/templatetags/registrasion_tags.py @@ -1,11 +1,47 @@ from registrasion import models as rego +from collections import namedtuple from django import template +from django.db.models import Sum register = template.Library() +ProductAndQuantity = namedtuple("ProductAndQuantity", ["product", "quantity"]) @register.assignment_tag(takes_context=True) def available_categories(context): ''' Returns all of the available product categories ''' return rego.Category.objects.all() + +@register.assignment_tag(takes_context=True) +def invoices(context): + ''' Returns all of the invoices that this user has. ''' + return rego.Invoice.objects.filter(cart__user=context.request.user) + +@register.assignment_tag(takes_context=True) +def items_pending(context): + ''' Returns all of the items that this user has in their current cart, + and is awaiting payment. ''' + + all_items = rego.ProductItem.objects.filter( + cart__user=context.request.user, + cart__active=True, + ) + return all_items + +@register.assignment_tag(takes_context=True) +def items_purchased(context): + ''' Returns all of the items that this user has purchased ''' + + all_items = rego.ProductItem.objects.filter( + cart__user=context.request.user, + cart__active=False, + ) + + products = set(item.product for item in all_items) + out = [] + for product in products: + pp = all_items.filter(product=product) + quantity = pp.aggregate(Sum("quantity"))["quantity__sum"] + out.append(ProductAndQuantity(product, quantity)) + return out diff --git a/registrasion/tests/test_discount.py b/registrasion/tests/test_discount.py index bb1c4bfe..222afc09 100644 --- a/registrasion/tests/test_discount.py +++ b/registrasion/tests/test_discount.py @@ -3,7 +3,9 @@ import pytz from decimal import Decimal from registrasion import models as rego +from registrasion.controllers import discount from registrasion.controllers.cart import CartController +from registrasion.controllers.invoice import InvoiceController from test_cart import RegistrationCartTestCase @@ -13,7 +15,11 @@ UTC = pytz.timezone('UTC') class DiscountTestCase(RegistrationCartTestCase): @classmethod - def add_discount_prod_1_includes_prod_2(cls, amount=Decimal(100)): + def add_discount_prod_1_includes_prod_2( + cls, + amount=Decimal(100), + quantity=2, + ): discount = rego.IncludedProductDiscount.objects.create( description="PROD_1 includes PROD_2 " + str(amount) + "%", ) @@ -24,7 +30,7 @@ class DiscountTestCase(RegistrationCartTestCase): discount=discount, product=cls.PROD_2, percentage=amount, - quantity=2 + quantity=quantity, ).save() return discount @@ -32,7 +38,8 @@ class DiscountTestCase(RegistrationCartTestCase): def add_discount_prod_1_includes_cat_2( cls, amount=Decimal(100), - quantity=2): + quantity=2, + ): discount = rego.IncludedProductDiscount.objects.create( description="PROD_1 includes CAT_2 " + str(amount) + "%", ) @@ -47,6 +54,33 @@ class DiscountTestCase(RegistrationCartTestCase): ).save() return discount + @classmethod + def add_discount_prod_1_includes_prod_3_and_prod_4( + cls, + amount=Decimal(100), + quantity=2, + ): + discount = rego.IncludedProductDiscount.objects.create( + description="PROD_1 includes PROD_3 and PROD_4 " + + str(amount) + "%", + ) + discount.save() + discount.enabling_products.add(cls.PROD_1) + discount.save() + rego.DiscountForProduct.objects.create( + discount=discount, + product=cls.PROD_3, + percentage=amount, + quantity=quantity, + ).save() + rego.DiscountForProduct.objects.create( + discount=discount, + product=cls.PROD_4, + percentage=amount, + quantity=quantity, + ).save() + return discount + def test_discount_is_applied(self): self.add_discount_prod_1_includes_prod_2() @@ -214,3 +248,132 @@ class DiscountTestCase(RegistrationCartTestCase): discount_items = list(cart.cart.discountitem_set.all()) # The discount is applied. self.assertEqual(1, len(discount_items)) + + # Tests for the discount.available_discounts enumerator + def test_enumerate_no_discounts_for_no_input(self): + discounts = discount.available_discounts(self.USER_1, [], []) + self.assertEqual(0, len(discounts)) + + def test_enumerate_no_discounts_if_condition_not_met(self): + self.add_discount_prod_1_includes_cat_2(quantity=1) + + discounts = discount.available_discounts( + self.USER_1, + [], + [self.PROD_3], + ) + self.assertEqual(0, len(discounts)) + + discounts = discount.available_discounts(self.USER_1, [self.CAT_2], []) + self.assertEqual(0, len(discounts)) + + def test_category_discount_appears_once_if_met_twice(self): + self.add_discount_prod_1_includes_cat_2(quantity=1) + + cart = CartController.for_user(self.USER_1) + cart.add_to_cart(self.PROD_1, 1) # Enable the discount + + discounts = discount.available_discounts( + self.USER_1, + [self.CAT_2], + [self.PROD_3], + ) + self.assertEqual(1, len(discounts)) + + def test_category_discount_appears_with_category(self): + self.add_discount_prod_1_includes_cat_2(quantity=1) + + cart = CartController.for_user(self.USER_1) + cart.add_to_cart(self.PROD_1, 1) # Enable the discount + + discounts = discount.available_discounts(self.USER_1, [self.CAT_2], []) + self.assertEqual(1, len(discounts)) + + def test_category_discount_appears_with_product(self): + self.add_discount_prod_1_includes_cat_2(quantity=1) + + cart = CartController.for_user(self.USER_1) + cart.add_to_cart(self.PROD_1, 1) # Enable the discount + + discounts = discount.available_discounts( + self.USER_1, + [], + [self.PROD_3], + ) + self.assertEqual(1, len(discounts)) + + def test_category_discount_appears_once_with_two_valid_product(self): + self.add_discount_prod_1_includes_cat_2(quantity=1) + + cart = CartController.for_user(self.USER_1) + cart.add_to_cart(self.PROD_1, 1) # Enable the discount + + discounts = discount.available_discounts( + self.USER_1, + [], + [self.PROD_3, self.PROD_4] + ) + self.assertEqual(1, len(discounts)) + + def test_product_discount_appears_with_product(self): + self.add_discount_prod_1_includes_prod_2(quantity=1) + + cart = CartController.for_user(self.USER_1) + cart.add_to_cart(self.PROD_1, 1) # Enable the discount + + discounts = discount.available_discounts( + self.USER_1, + [], + [self.PROD_2], + ) + self.assertEqual(1, len(discounts)) + + def test_product_discount_does_not_appear_with_category(self): + self.add_discount_prod_1_includes_prod_2(quantity=1) + + cart = CartController.for_user(self.USER_1) + cart.add_to_cart(self.PROD_1, 1) # Enable the discount + + discounts = discount.available_discounts(self.USER_1, [self.CAT_1], []) + self.assertEqual(0, len(discounts)) + + def test_discount_quantity_is_correct_before_first_purchase(self): + self.add_discount_prod_1_includes_cat_2(quantity=2) + + cart = CartController.for_user(self.USER_1) + cart.add_to_cart(self.PROD_1, 1) # Enable the discount + cart.add_to_cart(self.PROD_3, 1) # Exhaust the quantity + + discounts = discount.available_discounts(self.USER_1, [self.CAT_2], []) + self.assertEqual(2, discounts[0].quantity) + inv = InvoiceController.for_cart(cart.cart) + inv.pay("Dummy reference", inv.invoice.value) + self.assertTrue(inv.invoice.paid) + + def test_discount_quantity_is_correct_after_first_purchase(self): + self.test_discount_quantity_is_correct_before_first_purchase() + + cart = CartController.for_user(self.USER_1) + cart.add_to_cart(self.PROD_3, 1) # Exhaust the quantity + + discounts = discount.available_discounts(self.USER_1, [self.CAT_2], []) + self.assertEqual(1, discounts[0].quantity) + inv = InvoiceController.for_cart(cart.cart) + inv.pay("Dummy reference", inv.invoice.value) + self.assertTrue(inv.invoice.paid) + + def test_discount_is_gone_after_quantity_exhausted(self): + self.test_discount_quantity_is_correct_after_first_purchase() + discounts = discount.available_discounts(self.USER_1, [self.CAT_2], []) + self.assertEqual(0, len(discounts)) + + def test_product_discount_enabled_twice_appears_twice(self): + self.add_discount_prod_1_includes_prod_3_and_prod_4(quantity=2) + cart = CartController.for_user(self.USER_1) + cart.add_to_cart(self.PROD_1, 1) # Enable the discount + discounts = discount.available_discounts( + self.USER_1, + [], + [self.PROD_3, self.PROD_4], + ) + self.assertEqual(2, len(discounts)) diff --git a/registrasion/tests/test_enabling_condition.py b/registrasion/tests/test_enabling_condition.py index 09433cdd..e7155dea 100644 --- a/registrasion/tests/test_enabling_condition.py +++ b/registrasion/tests/test_enabling_condition.py @@ -4,6 +4,7 @@ from django.core.exceptions import ValidationError from registrasion import models as rego from registrasion.controllers.cart import CartController +from registrasion.controllers.product import ProductController from test_cart import RegistrationCartTestCase @@ -155,3 +156,80 @@ class EnablingConditionTestCases(RegistrationCartTestCase): cart_1.add_to_cart(self.PROD_1, 1) cart_1.add_to_cart(self.PROD_3, 1) # Meets the category condition cart_1.add_to_cart(self.PROD_1, 1) + + def test_available_products_works_with_no_conditions_set(self): + prods = ProductController.available_products( + self.USER_1, + category=self.CAT_1, + ) + + self.assertTrue(self.PROD_1 in prods) + self.assertTrue(self.PROD_2 in prods) + + prods = ProductController.available_products( + self.USER_1, + category=self.CAT_2, + ) + + self.assertTrue(self.PROD_3 in prods) + self.assertTrue(self.PROD_4 in prods) + + prods = ProductController.available_products( + self.USER_1, + products=[self.PROD_1, self.PROD_2, self.PROD_3, self.PROD_4], + ) + + self.assertTrue(self.PROD_1 in prods) + self.assertTrue(self.PROD_2 in prods) + self.assertTrue(self.PROD_3 in prods) + self.assertTrue(self.PROD_4 in prods) + + def test_available_products_on_category_works_when_condition_not_met(self): + self.add_product_enabling_condition(mandatory=False) + + prods = ProductController.available_products( + self.USER_1, + category=self.CAT_1, + ) + + self.assertTrue(self.PROD_1 not in prods) + self.assertTrue(self.PROD_2 in prods) + + def test_available_products_on_category_works_when_condition_is_met(self): + self.add_product_enabling_condition(mandatory=False) + + cart_1 = CartController.for_user(self.USER_1) + cart_1.add_to_cart(self.PROD_2, 1) + + prods = ProductController.available_products( + self.USER_1, + category=self.CAT_1, + ) + + self.assertTrue(self.PROD_1 in prods) + self.assertTrue(self.PROD_2 in prods) + + def test_available_products_on_products_works_when_condition_not_met(self): + self.add_product_enabling_condition(mandatory=False) + + prods = ProductController.available_products( + self.USER_1, + products=[self.PROD_1, self.PROD_2], + ) + + self.assertTrue(self.PROD_1 not in prods) + self.assertTrue(self.PROD_2 in prods) + + def test_available_products_on_products_works_when_condition_is_met(self): + self.add_product_enabling_condition(mandatory=False) + + cart_1 = CartController.for_user(self.USER_1) + cart_1.add_to_cart(self.PROD_2, 1) + + prods = ProductController.available_products( + self.USER_1, + products=[self.PROD_1, self.PROD_2], + ) + + self.assertTrue(self.PROD_1 in prods) + self.assertTrue(self.PROD_2 in prods) diff --git a/registrasion/views.py b/registrasion/views.py index 070f645a..83e80a0d 100644 --- a/registrasion/views.py +++ b/registrasion/views.py @@ -1,7 +1,9 @@ from registrasion import forms from registrasion import models as rego +from registrasion.controllers import discount from registrasion.controllers.cart import CartController from registrasion.controllers.invoice import InvoiceController +from registrasion.controllers.product import ProductController from django.contrib.auth.decorators import login_required from django.core.exceptions import ObjectDoesNotExist @@ -95,19 +97,21 @@ def product_category(request, category_id): category = rego.Category.objects.get(pk=category_id) current_cart = CartController.for_user(request.user) - CategoryForm = forms.CategoryForm(category) - attendee = rego.Attendee.get_instance(request.user) products = rego.Product.objects.filter(category=category) products = products.order_by("order") + products = ProductController.available_products( + request.user, + products=products, + ) + ProductsForm = forms.ProductsForm(products) if request.method == "POST": - cat_form = CategoryForm( + cat_form = ProductsForm( request.POST, request.FILES, prefix=PRODUCTS_FORM_PREFIX) - cat_form.disable_products_for_user(request.user) voucher_form = forms.VoucherForm( request.POST, prefix=VOUCHERS_FORM_PREFIX) @@ -165,14 +169,17 @@ def product_category(request, category_id): quantity = 0 quantities.append((product, quantity)) - initial = CategoryForm.initial_data(quantities) - cat_form = CategoryForm(prefix=PRODUCTS_FORM_PREFIX, initial=initial) - cat_form.disable_products_for_user(request.user) + cat_form = ProductsForm( + prefix=PRODUCTS_FORM_PREFIX, + product_quantities=quantities, + ) voucher_form = forms.VoucherForm(prefix=VOUCHERS_FORM_PREFIX) + discounts = discount.available_discounts(request.user, [], products) data = { "category": category, + "discounts": discounts, "form": cat_form, "voucher_form": voucher_form, }