diff --git a/registrasion/controllers/cart.py b/registrasion/controllers/cart.py index 2ff9f171..283c3119 100644 --- a/registrasion/controllers/cart.py +++ b/registrasion/controllers/cart.py @@ -240,12 +240,11 @@ class CartController(object): by_cat[product.category].append((product, quantity)) # Pre-annotate categories - r = CategoryController.attach_user_remainders(self.cart.user, by_cat) - with_remainders = dict((cat, cat) for cat in r) + remainders = CategoryController.user_remainders(self.cart.user) # Test each category limit here for category in by_cat: - limit = with_remainders[category].remainder + limit = remainders[category.id] # Get the amount so far in the cart to_add = sum(i[1] for i in by_cat[category]) diff --git a/registrasion/controllers/category.py b/registrasion/controllers/category.py index 9db8ca9e..4681f48b 100644 --- a/registrasion/controllers/category.py +++ b/registrasion/controllers/category.py @@ -39,17 +39,16 @@ class CategoryController(object): return set(i.category for i in available) @classmethod - def attach_user_remainders(cls, user, categories): + def user_remainders(cls, user): ''' Return: - queryset(inventory.Product): A queryset containing items from - ``categories``, with an extra attribute -- remainder = the amount - of items from this category that is remaining. + Mapping[int->int]: A dictionary that maps the category ID to the + user's remainder for that category. + ''' - ids = [category.id for category in categories] - categories = inventory.Category.objects.filter(id__in=ids) + categories = inventory.Category.objects.all() cart_filter = ( Q(product__productitem__cart__user=user) & @@ -73,12 +72,4 @@ class CategoryController(object): categories = categories.annotate(remainder=remainder) - return categories - - def user_quantity_remaining(self, user): - ''' Returns the quantity of this product that the user add in the - current cart. ''' - - with_remainders = self.attach_user_remainders(user, [self.category]) - - return with_remainders[0].remainder + return dict((cat.id, cat.remainder) for cat in categories) diff --git a/registrasion/controllers/product.py b/registrasion/controllers/product.py index 0e2e984f..0810902b 100644 --- a/registrasion/controllers/product.py +++ b/registrasion/controllers/product.py @@ -34,16 +34,13 @@ class ProductController(object): if products is not None: all_products = set(itertools.chain(all_products, products)) - categories = set(product.category for product in all_products) - r = CategoryController.attach_user_remainders(user, categories) - cat_quants = dict((c, c) for c in r) - + category_remainders = CategoryController.user_remainders(user) product_remainders = ProductController.user_remainders(user) passed_limits = set( product for product in all_products - if cat_quants[product.category].remainder > 0 + if category_remainders[product.category.id] > 0 if product_remainders[product.id] > 0 )