From 312fffd137c20cb6e351201f9290f1b81fadf763 Mon Sep 17 00:00:00 2001
From: Christopher Neugebauer <chrisjrn@gmail.com>
Date: Sun, 3 Apr 2016 09:45:39 +1000
Subject: [PATCH] Adds negative quantity tests to _test_limits, and removes
 _set_quantity_old.

---
 registrasion/controllers/cart.py | 104 ++++++++++++++-----------------
 1 file changed, 48 insertions(+), 56 deletions(-)

diff --git a/registrasion/controllers/cart.py b/registrasion/controllers/cart.py
index db9c912a..5fab4b30 100644
--- a/registrasion/controllers/cart.py
+++ b/registrasion/controllers/cart.py
@@ -1,6 +1,7 @@
 import collections
 import datetime
 import discount
+import itertools
 
 from django.core.exceptions import ObjectDoesNotExist
 from django.core.exceptions import ValidationError
@@ -73,25 +74,64 @@ class CartController(object):
 
     @transaction.atomic
     def set_quantities(self, product_quantities):
+        ''' Sets the quantities on each of the products on each of the
+        products specified. Raises an exception (ValidationError) if a limit
+        is violated. `product_quantities` is an iterable of (product, quantity)
+        pairs. '''
 
         items_in_cart = rego.ProductItem.objects.filter(cart=self.cart)
+        product_quantities = list(product_quantities)
 
-        # Remove all items that we're updating
-        items_in_cart.filter(
-            product__in=(i[0] for i in product_quantities),
-        ).delete()
+        # n.b need to add have the existing items first so that the new
+        # items override the old ones.
+        all_product_quantities = dict(itertools.chain(
+            ((i.product, i.quantity) for i in items_in_cart.all()),
+            product_quantities,
+        )).items()
 
-        all_product_quantities = list(product_quantities) + [
-            (i.product, i.quantity) for i in items_in_cart.all()
-        ]
+        # Validate that the limits we're adding are OK
         self._test_limits(all_product_quantities)
 
         for product, quantity in product_quantities:
-            self._set_quantity_old(product, quantity)
+            try:
+                product_item = rego.ProductItem.objects.get(
+                    cart=self.cart,
+                    product=product,
+                )
+                product_item.quantity = quantity
+                product_item.save()
+            except ObjectDoesNotExist:
+                rego.ProductItem.objects.create(
+                    cart=self.cart,
+                    product=product,
+                    quantity=quantity,
+                )
+
+        items_in_cart.filter(quantity=0).delete()
 
         self.end_batch()
 
     def _test_limits(self, product_quantities):
+        ''' Tests that the quantity changes we intend to make do not violate
+        the limits and enabling conditions imposed on the products. '''
+
+        # Test each product limit here
+        for product, quantity in product_quantities:
+            if quantity < 0:
+                # TODO: batch errors
+                raise ValidationError("Value must be zero or greater.")
+
+            prod = ProductController(product)
+            limit = prod.user_quantity_remaining(self.cart.user)
+
+            if quantity > limit:
+                # TODO: batch errors
+                raise ValidationError(
+                    "You may only have %d of product: %s" % (
+                        limit, product.name,
+                    )
+                )
+
         # Collect by category
         by_cat = collections.defaultdict(list)
         for product, quantity in product_quantities:
@@ -113,19 +153,6 @@ class CartController(object):
                     )
                 )
 
-        # Test each product limit here
-        for product, quantity in product_quantities:
-            prod = ProductController(product)
-            limit = prod.user_quantity_remaining(self.cart.user)
-
-            if quantity > limit:
-                # TODO: batch errors
-                raise ValidationError(
-                    "You may only have %d of product: %s" % (
-                        limit, cat.name,
-                    )
-                )
-
         # Test the enabling conditions
         errs = ConditionController.test_enabling_conditions(
             self.cart.user,
@@ -142,41 +169,6 @@ class CartController(object):
 
         self.set_quantities(((product, quantity),))
 
-    def _set_quantity_old(self, product, quantity):
-        ''' Sets the _quantity_ of the given _product_ in the cart to the given
-        _quantity_. '''
-
-        if quantity < 0:
-            raise ValidationError("Cannot have fewer than 0 items in cart.")
-
-        try:
-            product_item = rego.ProductItem.objects.get(
-                cart=self.cart,
-                product=product)
-            old_quantity = product_item.quantity
-
-            if quantity == 0:
-                product_item.delete()
-                return
-        except ObjectDoesNotExist:
-            if quantity == 0:
-                return
-
-            product_item = rego.ProductItem.objects.create(
-                cart=self.cart,
-                product=product,
-                quantity=0,
-            )
-
-            old_quantity = 0
-
-        # Validate the addition to the cart
-        adjustment = quantity - old_quantity
-        prod = ProductController(product)
-
-        product_item.quantity = quantity
-        product_item.save()
-
     def add_to_cart(self, product, quantity):
         ''' Adds _quantity_ of the given _product_ to the cart. Raises
         ValidationError if constraints are violated.'''