Adds negative quantity tests to _test_limits, and removes _set_quantity_old.
This commit is contained in:
		
							parent
							
								
									6c9a68dc5b
								
							
						
					
					
						commit
						312fffd137
					
				
					 1 changed files with 48 additions and 56 deletions
				
			
		| 
						 | 
					@ -1,6 +1,7 @@
 | 
				
			||||||
import collections
 | 
					import collections
 | 
				
			||||||
import datetime
 | 
					import datetime
 | 
				
			||||||
import discount
 | 
					import discount
 | 
				
			||||||
 | 
					import itertools
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from django.core.exceptions import ObjectDoesNotExist
 | 
					from django.core.exceptions import ObjectDoesNotExist
 | 
				
			||||||
from django.core.exceptions import ValidationError
 | 
					from django.core.exceptions import ValidationError
 | 
				
			||||||
| 
						 | 
					@ -73,25 +74,64 @@ class CartController(object):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @transaction.atomic
 | 
					    @transaction.atomic
 | 
				
			||||||
    def set_quantities(self, product_quantities):
 | 
					    def set_quantities(self, product_quantities):
 | 
				
			||||||
 | 
					        ''' Sets the quantities on each of the products on each of the
 | 
				
			||||||
 | 
					        products specified. Raises an exception (ValidationError) if a limit
 | 
				
			||||||
 | 
					        is violated. `product_quantities` is an iterable of (product, quantity)
 | 
				
			||||||
 | 
					        pairs. '''
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        items_in_cart = rego.ProductItem.objects.filter(cart=self.cart)
 | 
					        items_in_cart = rego.ProductItem.objects.filter(cart=self.cart)
 | 
				
			||||||
 | 
					        product_quantities = list(product_quantities)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # Remove all items that we're updating
 | 
					        # n.b need to add have the existing items first so that the new
 | 
				
			||||||
        items_in_cart.filter(
 | 
					        # items override the old ones.
 | 
				
			||||||
            product__in=(i[0] for i in product_quantities),
 | 
					        all_product_quantities = dict(itertools.chain(
 | 
				
			||||||
        ).delete()
 | 
					            ((i.product, i.quantity) for i in items_in_cart.all()),
 | 
				
			||||||
 | 
					            product_quantities,
 | 
				
			||||||
 | 
					        )).items()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        all_product_quantities = list(product_quantities) + [
 | 
					        # Validate that the limits we're adding are OK
 | 
				
			||||||
            (i.product, i.quantity) for i in items_in_cart.all()
 | 
					 | 
				
			||||||
        ]
 | 
					 | 
				
			||||||
        self._test_limits(all_product_quantities)
 | 
					        self._test_limits(all_product_quantities)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        for product, quantity in product_quantities:
 | 
					        for product, quantity in product_quantities:
 | 
				
			||||||
            self._set_quantity_old(product, quantity)
 | 
					            try:
 | 
				
			||||||
 | 
					                product_item = rego.ProductItem.objects.get(
 | 
				
			||||||
 | 
					                    cart=self.cart,
 | 
				
			||||||
 | 
					                    product=product,
 | 
				
			||||||
 | 
					                )
 | 
				
			||||||
 | 
					                product_item.quantity = quantity
 | 
				
			||||||
 | 
					                product_item.save()
 | 
				
			||||||
 | 
					            except ObjectDoesNotExist:
 | 
				
			||||||
 | 
					                rego.ProductItem.objects.create(
 | 
				
			||||||
 | 
					                    cart=self.cart,
 | 
				
			||||||
 | 
					                    product=product,
 | 
				
			||||||
 | 
					                    quantity=quantity,
 | 
				
			||||||
 | 
					                )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        items_in_cart.filter(quantity=0).delete()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.end_batch()
 | 
					        self.end_batch()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def _test_limits(self, product_quantities):
 | 
					    def _test_limits(self, product_quantities):
 | 
				
			||||||
 | 
					        ''' Tests that the quantity changes we intend to make do not violate
 | 
				
			||||||
 | 
					        the limits and enabling conditions imposed on the products. '''
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Test each product limit here
 | 
				
			||||||
 | 
					        for product, quantity in product_quantities:
 | 
				
			||||||
 | 
					            if quantity < 0:
 | 
				
			||||||
 | 
					                # TODO: batch errors
 | 
				
			||||||
 | 
					                raise ValidationError("Value must be zero or greater.")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            prod = ProductController(product)
 | 
				
			||||||
 | 
					            limit = prod.user_quantity_remaining(self.cart.user)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            if quantity > limit:
 | 
				
			||||||
 | 
					                # TODO: batch errors
 | 
				
			||||||
 | 
					                raise ValidationError(
 | 
				
			||||||
 | 
					                    "You may only have %d of product: %s" % (
 | 
				
			||||||
 | 
					                        limit, product.name,
 | 
				
			||||||
 | 
					                    )
 | 
				
			||||||
 | 
					                )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # Collect by category
 | 
					        # Collect by category
 | 
				
			||||||
        by_cat = collections.defaultdict(list)
 | 
					        by_cat = collections.defaultdict(list)
 | 
				
			||||||
        for product, quantity in product_quantities:
 | 
					        for product, quantity in product_quantities:
 | 
				
			||||||
| 
						 | 
					@ -113,19 +153,6 @@ class CartController(object):
 | 
				
			||||||
                    )
 | 
					                    )
 | 
				
			||||||
                )
 | 
					                )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # Test each product limit here
 | 
					 | 
				
			||||||
        for product, quantity in product_quantities:
 | 
					 | 
				
			||||||
            prod = ProductController(product)
 | 
					 | 
				
			||||||
            limit = prod.user_quantity_remaining(self.cart.user)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            if quantity > limit:
 | 
					 | 
				
			||||||
                # TODO: batch errors
 | 
					 | 
				
			||||||
                raise ValidationError(
 | 
					 | 
				
			||||||
                    "You may only have %d of product: %s" % (
 | 
					 | 
				
			||||||
                        limit, cat.name,
 | 
					 | 
				
			||||||
                    )
 | 
					 | 
				
			||||||
                )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Test the enabling conditions
 | 
					        # Test the enabling conditions
 | 
				
			||||||
        errs = ConditionController.test_enabling_conditions(
 | 
					        errs = ConditionController.test_enabling_conditions(
 | 
				
			||||||
            self.cart.user,
 | 
					            self.cart.user,
 | 
				
			||||||
| 
						 | 
					@ -142,41 +169,6 @@ class CartController(object):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.set_quantities(((product, quantity),))
 | 
					        self.set_quantities(((product, quantity),))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def _set_quantity_old(self, product, quantity):
 | 
					 | 
				
			||||||
        ''' Sets the _quantity_ of the given _product_ in the cart to the given
 | 
					 | 
				
			||||||
        _quantity_. '''
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if quantity < 0:
 | 
					 | 
				
			||||||
            raise ValidationError("Cannot have fewer than 0 items in cart.")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        try:
 | 
					 | 
				
			||||||
            product_item = rego.ProductItem.objects.get(
 | 
					 | 
				
			||||||
                cart=self.cart,
 | 
					 | 
				
			||||||
                product=product)
 | 
					 | 
				
			||||||
            old_quantity = product_item.quantity
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            if quantity == 0:
 | 
					 | 
				
			||||||
                product_item.delete()
 | 
					 | 
				
			||||||
                return
 | 
					 | 
				
			||||||
        except ObjectDoesNotExist:
 | 
					 | 
				
			||||||
            if quantity == 0:
 | 
					 | 
				
			||||||
                return
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            product_item = rego.ProductItem.objects.create(
 | 
					 | 
				
			||||||
                cart=self.cart,
 | 
					 | 
				
			||||||
                product=product,
 | 
					 | 
				
			||||||
                quantity=0,
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            old_quantity = 0
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Validate the addition to the cart
 | 
					 | 
				
			||||||
        adjustment = quantity - old_quantity
 | 
					 | 
				
			||||||
        prod = ProductController(product)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        product_item.quantity = quantity
 | 
					 | 
				
			||||||
        product_item.save()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def add_to_cart(self, product, quantity):
 | 
					    def add_to_cart(self, product, quantity):
 | 
				
			||||||
        ''' Adds _quantity_ of the given _product_ to the cart. Raises
 | 
					        ''' Adds _quantity_ of the given _product_ to the cart. Raises
 | 
				
			||||||
        ValidationError if constraints are violated.'''
 | 
					        ValidationError if constraints are violated.'''
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
	Add table
		
		Reference in a new issue