Merge branch 'query-optimisation'

This commit is contained in:
Christopher Neugebauer 2016-04-29 11:23:17 +10:00
commit 6956c78b0d
14 changed files with 1064 additions and 449 deletions

View file

@ -1,6 +1,6 @@
import collections import collections
import contextlib
import datetime import datetime
import discount
import functools import functools
import itertools import itertools
@ -8,6 +8,7 @@ from django.core.exceptions import ObjectDoesNotExist
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
from django.db import transaction from django.db import transaction
from django.db.models import Max from django.db.models import Max
from django.db.models import Q
from django.utils import timezone from django.utils import timezone
from registrasion.exceptions import CartValidationError from registrasion.exceptions import CartValidationError
@ -15,19 +16,27 @@ from registrasion.models import commerce
from registrasion.models import conditions from registrasion.models import conditions
from registrasion.models import inventory from registrasion.models import inventory
from category import CategoryController from .category import CategoryController
from conditions import ConditionController from .discount import DiscountController
from product import ProductController from .flag import FlagController
from .product import ProductController
def _modifies_cart(func): def _modifies_cart(func):
''' Decorator that makes the wrapped function raise ValidationError ''' Decorator that makes the wrapped function raise ValidationError
if we're doing something that could modify the cart. ''' if we're doing something that could modify the cart.
It also wraps the execution of this function in a database transaction,
and marks the boundaries of a cart operations batch.
'''
@functools.wraps(func) @functools.wraps(func)
def inner(self, *a, **k): def inner(self, *a, **k):
self._fail_if_cart_is_not_active() self._fail_if_cart_is_not_active()
return func(self, *a, **k) with transaction.atomic():
with CartController.operations_batch(self.cart.user) as mark:
mark.mark = True # Marker that we've modified the cart
return func(self, *a, **k)
return inner return inner
@ -55,13 +64,65 @@ class CartController(object):
) )
return cls(existing) return cls(existing)
# Marks the carts that are currently in batches
_FOR_USER = {}
_BATCH_COUNT = collections.defaultdict(int)
_MODIFIED_CARTS = set()
class _ModificationMarker(object):
pass
@classmethod
@contextlib.contextmanager
def operations_batch(cls, user):
''' Marks the boundary for a batch of operations on a user's cart.
These markers can be nested. Only on exiting the outermost marker will
a batch be ended.
When a batch is ended, discounts are recalculated, and the cart's
revision is increased.
'''
if user not in cls._FOR_USER:
_ctrl = cls.for_user(user)
cls._FOR_USER[user] = (_ctrl, _ctrl.cart.id)
ctrl, _id = cls._FOR_USER[user]
cls._BATCH_COUNT[_id] += 1
try:
success = False
marker = cls._ModificationMarker()
yield marker
if hasattr(marker, "mark"):
cls._MODIFIED_CARTS.add(_id)
success = True
finally:
cls._BATCH_COUNT[_id] -= 1
# Only end on the outermost batch marker, and only if
# it excited cleanly, and a modification occurred
modified = _id in cls._MODIFIED_CARTS
outermost = cls._BATCH_COUNT[_id] == 0
if modified and outermost and success:
ctrl._end_batch()
cls._MODIFIED_CARTS.remove(_id)
# Clear out the cache on the outermost operation
if outermost:
del cls._FOR_USER[user]
def _fail_if_cart_is_not_active(self): def _fail_if_cart_is_not_active(self):
self.cart.refresh_from_db() self.cart.refresh_from_db()
if self.cart.status != commerce.Cart.STATUS_ACTIVE: if self.cart.status != commerce.Cart.STATUS_ACTIVE:
raise ValidationError("You can only amend active carts.") raise ValidationError("You can only amend active carts.")
@_modifies_cart def _autoextend_reservation(self):
def extend_reservation(self):
''' Updates the cart's time last updated value, which is used to ''' Updates the cart's time last updated value, which is used to
determine whether the cart has reserved the items and discounts it determine whether the cart has reserved the items and discounts it
holds. ''' holds. '''
@ -83,21 +144,25 @@ class CartController(object):
self.cart.time_last_updated = timezone.now() self.cart.time_last_updated = timezone.now()
self.cart.reservation_duration = max(reservations) self.cart.reservation_duration = max(reservations)
@_modifies_cart def _end_batch(self):
def end_batch(self):
''' Performs operations that occur occur at the end of a batch of ''' Performs operations that occur occur at the end of a batch of
product changes/voucher applications etc. product changes/voucher applications etc.
THIS SHOULD BE PRIVATE
You need to call this after you've finished modifying the user's cart.
This is normally done by wrapping a block of code using
``operations_batch``.
''' '''
self.recalculate_discounts() self.cart.refresh_from_db()
self.extend_reservation() self._recalculate_discounts()
self._autoextend_reservation()
self.cart.revision += 1 self.cart.revision += 1
self.cart.save() self.cart.save()
@_modifies_cart @_modifies_cart
@transaction.atomic
def set_quantities(self, product_quantities): def set_quantities(self, product_quantities):
''' Sets the quantities on each of the products on each of the ''' Sets the quantities on each of the products on each of the
products specified. Raises an exception (ValidationError) if a limit products specified. Raises an exception (ValidationError) if a limit
@ -122,24 +187,28 @@ class CartController(object):
# Validate that the limits we're adding are OK # Validate that the limits we're adding are OK
self._test_limits(all_product_quantities) self._test_limits(all_product_quantities)
new_items = []
products = []
for product, quantity in product_quantities: for product, quantity in product_quantities:
try: products.append(product)
product_item = commerce.ProductItem.objects.get(
cart=self.cart,
product=product,
)
product_item.quantity = quantity
product_item.save()
except ObjectDoesNotExist:
commerce.ProductItem.objects.create(
cart=self.cart,
product=product,
quantity=quantity,
)
items_in_cart.filter(quantity=0).delete() if quantity == 0:
continue
self.end_batch() item = commerce.ProductItem(
cart=self.cart,
product=product,
quantity=quantity,
)
new_items.append(item)
to_delete = (
Q(quantity=0) |
Q(product__in=products)
)
items_in_cart.filter(to_delete).delete()
commerce.ProductItem.objects.bulk_create(new_items)
def _test_limits(self, product_quantities): def _test_limits(self, product_quantities):
''' Tests that the quantity changes we intend to make do not violate ''' Tests that the quantity changes we intend to make do not violate
@ -147,13 +216,17 @@ class CartController(object):
errors = [] errors = []
# Pre-annotate products
products = [p for (p, q) in product_quantities]
r = ProductController.attach_user_remainders(self.cart.user, products)
with_remainders = dict((p, p) for p in r)
# Test each product limit here # Test each product limit here
for product, quantity in product_quantities: for product, quantity in product_quantities:
if quantity < 0: if quantity < 0:
errors.append((product, "Value must be zero or greater.")) errors.append((product, "Value must be zero or greater."))
prod = ProductController(product) limit = with_remainders[product].remainder
limit = prod.user_quantity_remaining(self.cart.user)
if quantity > limit: if quantity > limit:
errors.append(( errors.append((
@ -168,10 +241,13 @@ class CartController(object):
for product, quantity in product_quantities: for product, quantity in product_quantities:
by_cat[product.category].append((product, quantity)) by_cat[product.category].append((product, quantity))
# Pre-annotate categories
r = CategoryController.attach_user_remainders(self.cart.user, by_cat)
with_remainders = dict((cat, cat) for cat in r)
# Test each category limit here # Test each category limit here
for category in by_cat: for category in by_cat:
ctrl = CategoryController(category) limit = with_remainders[category].remainder
limit = ctrl.user_quantity_remaining(self.cart.user)
# Get the amount so far in the cart # Get the amount so far in the cart
to_add = sum(i[1] for i in by_cat[category]) to_add = sum(i[1] for i in by_cat[category])
@ -185,7 +261,7 @@ class CartController(object):
)) ))
# Test the flag conditions # Test the flag conditions
errs = ConditionController.test_flags( errs = FlagController.test_flags(
self.cart.user, self.cart.user,
product_quantities=product_quantities, product_quantities=product_quantities,
) )
@ -212,7 +288,6 @@ class CartController(object):
# If successful... # If successful...
self.cart.vouchers.add(voucher) self.cart.vouchers.add(voucher)
self.end_batch()
def _test_voucher(self, voucher): def _test_voucher(self, voucher):
''' Tests whether this voucher is allowed to be applied to this cart. ''' Tests whether this voucher is allowed to be applied to this cart.
@ -294,6 +369,7 @@ class CartController(object):
errors.append(ve) errors.append(ve)
items = commerce.ProductItem.objects.filter(cart=cart) items = commerce.ProductItem.objects.filter(cart=cart)
items = items.select_related("product", "product__category")
product_quantities = list((i.product, i.quantity) for i in items) product_quantities = list((i.product, i.quantity) for i in items)
try: try:
@ -307,19 +383,24 @@ class CartController(object):
self._append_errors(errors, ve) self._append_errors(errors, ve)
# Validate the discounts # Validate the discounts
discount_items = commerce.DiscountItem.objects.filter(cart=cart) # TODO: refactor in terms of available_discounts
seen_discounts = set() # why aren't we doing that here?!
# def available_discounts(cls, user, categories, products):
products = [i.product for i in items]
discounts_with_quantity = DiscountController.available_discounts(
user,
[],
products,
)
discounts = set(i.discount.id for i in discounts_with_quantity)
discount_items = commerce.DiscountItem.objects.filter(cart=cart)
for discount_item in discount_items: for discount_item in discount_items:
discount = discount_item.discount discount = discount_item.discount
if discount in seen_discounts:
continue
seen_discounts.add(discount)
real_discount = conditions.DiscountBase.objects.get_subclass(
pk=discount.pk)
cond = ConditionController.for_condition(real_discount)
if not cond.is_met(user): if discount.id not in discounts:
errors.append( errors.append(
ValidationError("Discounts are no longer available") ValidationError("Discounts are no longer available")
) )
@ -328,7 +409,6 @@ class CartController(object):
raise ValidationError(errors) raise ValidationError(errors)
@_modifies_cart @_modifies_cart
@transaction.atomic
def fix_simple_errors(self): def fix_simple_errors(self):
''' This attempts to fix the easy errors raised by ValidationError. ''' This attempts to fix the easy errors raised by ValidationError.
This includes removing items from the cart that are no longer This includes removing items from the cart that are no longer
@ -360,11 +440,9 @@ class CartController(object):
self.set_quantities(zeros) self.set_quantities(zeros)
@_modifies_cart
@transaction.atomic @transaction.atomic
def recalculate_discounts(self): def _recalculate_discounts(self):
''' Calculates all of the discounts available for this product. ''' Calculates all of the discounts available for this product.'''
'''
# Delete the existing entries. # Delete the existing entries.
commerce.DiscountItem.objects.filter(cart=self.cart).delete() commerce.DiscountItem.objects.filter(cart=self.cart).delete()
@ -374,7 +452,11 @@ class CartController(object):
) )
products = [i.product for i in product_items] products = [i.product for i in product_items]
discounts = discount.available_discounts(self.cart.user, [], products) discounts = DiscountController.available_discounts(
self.cart.user,
[],
products,
)
# The highest-value discounts will apply to the highest-value # The highest-value discounts will apply to the highest-value
# products first. # products first.

View file

@ -1,7 +1,11 @@
from registrasion.models import commerce from registrasion.models import commerce
from registrasion.models import inventory from registrasion.models import inventory
from django.db.models import Case
from django.db.models import F, Q
from django.db.models import Sum from django.db.models import Sum
from django.db.models import When
from django.db.models import Value
class AllProducts(object): class AllProducts(object):
@ -34,25 +38,47 @@ class CategoryController(object):
return set(i.category for i in available) return set(i.category for i in available)
@classmethod
def attach_user_remainders(cls, user, categories):
'''
Return:
queryset(inventory.Product): A queryset containing items from
``categories``, with an extra attribute -- remainder = the amount
of items from this category that is remaining.
'''
ids = [category.id for category in categories]
categories = inventory.Category.objects.filter(id__in=ids)
cart_filter = (
Q(product__productitem__cart__user=user) &
Q(product__productitem__cart__status=commerce.Cart.STATUS_PAID)
)
quantity = When(
cart_filter,
then='product__productitem__quantity'
)
quantity_or_zero = Case(
quantity,
default=Value(0),
)
remainder = Case(
When(limit_per_user=None, then=Value(99999999)),
default=F('limit_per_user') - Sum(quantity_or_zero),
)
categories = categories.annotate(remainder=remainder)
return categories
def user_quantity_remaining(self, user): def user_quantity_remaining(self, user):
''' Returns the number of items from this category that the user may ''' Returns the quantity of this product that the user add in the
add in the current cart. ''' current cart. '''
cat_limit = self.category.limit_per_user with_remainders = self.attach_user_remainders(user, [self.category])
if cat_limit is None: return with_remainders[0].remainder
# We don't need to waste the following queries
return 99999999
carts = commerce.Cart.objects.filter(
user=user,
status=commerce.Cart.STATUS_PAID,
)
items = commerce.ProductItem.objects.filter(
cart__in=carts,
product__category=self.category,
)
cat_count = items.aggregate(Sum("quantity"))["quantity__sum"] or 0
return cat_limit - cat_count

View file

@ -1,36 +1,27 @@
import itertools from django.db.models import Case
import operator from django.db.models import F, Q
from collections import defaultdict
from collections import namedtuple
from django.db.models import Sum from django.db.models import Sum
from django.db.models import Value
from django.db.models import When
from django.utils import timezone from django.utils import timezone
from registrasion.models import commerce from registrasion.models import commerce
from registrasion.models import conditions from registrasion.models import conditions
from registrasion.models import inventory
ConditionAndRemainder = namedtuple( _BIG_QUANTITY = 99999999 # A big quantity
"ConditionAndRemainder",
(
"condition",
"remainder",
),
)
class ConditionController(object): class ConditionController(object):
''' Base class for testing conditions that activate Flag ''' Base class for testing conditions that activate Flag
or Discount objects. ''' or Discount objects. '''
def __init__(self): def __init__(self, condition):
pass self.condition = condition
@staticmethod @staticmethod
def for_condition(condition): def _controllers():
CONTROLLERS = { return {
conditions.CategoryFlag: CategoryConditionController, conditions.CategoryFlag: CategoryConditionController,
conditions.IncludedProductDiscount: ProductConditionController, conditions.IncludedProductDiscount: ProductConditionController,
conditions.ProductFlag: ProductConditionController, conditions.ProductFlag: ProductConditionController,
@ -42,137 +33,49 @@ class ConditionController(object):
conditions.VoucherFlag: VoucherConditionController, conditions.VoucherFlag: VoucherConditionController,
} }
@staticmethod
def for_type(cls):
return ConditionController._controllers()[cls]
@staticmethod
def for_condition(condition):
try: try:
return CONTROLLERS[type(condition)](condition) return ConditionController.for_type(type(condition))(condition)
except KeyError: except KeyError:
return ConditionController() return ConditionController()
SINGLE = True
PLURAL = False
NONE = True
SOME = False
MESSAGE = {
NONE: {
SINGLE:
"%(items)s is no longer available to you",
PLURAL:
"%(items)s are no longer available to you",
},
SOME: {
SINGLE:
"Only %(remainder)d of the following item remains: %(items)s",
PLURAL:
"Only %(remainder)d of the following items remain: %(items)s"
},
}
@classmethod @classmethod
def test_flags( def pre_filter(cls, queryset, user):
cls, user, products=None, product_quantities=None): ''' Returns only the flag conditions that might be available for this
''' Evaluates all of the flag conditions on the given products. user. It should hopefully reduce the number of queries that need to be
executed to determine if a flag is met.
If `product_quantities` is supplied, the condition is only met if it If this filtration implements the same query as is_met, then you should
will permit the sum of the product quantities for all of the products be able to implement ``is_met()`` in terms of this.
it covers. Otherwise, it will be met if at least one item can be
accepted.
If all flag conditions pass, an empty list is returned, otherwise Arguments:
a list is returned containing all of the products that are *not
enabled*. '''
if products is not None and product_quantities is not None: queryset (Queryset[c]): The canditate conditions.
raise ValueError("Please specify only products or "
"product_quantities")
elif products is None:
products = set(i[0] for i in product_quantities)
quantities = dict((product, quantity)
for product, quantity in product_quantities)
elif product_quantities is None:
products = set(products)
quantities = {}
# Get the conditions covered by the products themselves user (User): The user for whom we're testing these conditions.
prods = (
product.flagbase_set.select_subclasses()
for product in products
)
# Get the conditions covered by their categories
cats = (
category.flagbase_set.select_subclasses()
for category in set(product.category for product in products)
)
if products: Returns:
# Simplify the query. Queryset[c]: A subset of the conditions that pass the pre-filter
all_conditions = reduce(operator.or_, itertools.chain(prods, cats)) test for this user.
else:
all_conditions = []
# All disable-if-false conditions on a product need to be met '''
do_not_disable = defaultdict(lambda: True)
# At least one enable-if-true condition on a product must be met
do_enable = defaultdict(lambda: False)
# (if either sort of condition is present)
messages = {} # Default implementation does NOTHING.
return queryset
for condition in all_conditions: def passes_filter(self, user):
cond = cls.for_condition(condition) ''' Returns true if the condition passes the filter '''
remainder = cond.user_quantity_remaining(user)
# Get all products covered by this condition, and the products cls = type(self.condition)
# from the categories covered by this condition qs = cls.objects.filter(pk=self.condition.id)
cond_products = condition.products.all() return self.condition in self.pre_filter(qs, user)
from_category = inventory.Product.objects.filter(
category__in=condition.categories.all(),
).all()
all_products = cond_products | from_category
all_products = all_products.select_related("category")
# Remove the products that we aren't asking about
all_products = [
product
for product in all_products
if product in products
]
if quantities: def user_quantity_remaining(self, user, filtered=False):
consumed = sum(quantities[i] for i in all_products)
else:
consumed = 1
met = consumed <= remainder
if not met:
items = ", ".join(str(product) for product in all_products)
base = cls.MESSAGE[remainder == 0][len(all_products) == 1]
message = base % {"items": items, "remainder": remainder}
for product in all_products:
if condition.is_disable_if_false:
do_not_disable[product] &= met
else:
do_enable[product] |= met
if not met and product not in messages:
messages[product] = message
valid = {}
for product in itertools.chain(do_not_disable, do_enable):
if product in do_enable:
# If there's an enable-if-true, we need need of those met too.
# (do_not_disable will default to true otherwise)
valid[product] = do_not_disable[product] and do_enable[product]
elif product in do_not_disable:
# If there's a disable-if-false condition, all must be met
valid[product] = do_not_disable[product]
error_fields = [
(product, messages[product])
for product in valid if not valid[product]
]
return error_fields
def user_quantity_remaining(self, user):
''' Returns the number of items covered by this flag condition the ''' Returns the number of items covered by this flag condition the
user can add to the current cart. This default implementation returns user can add to the current cart. This default implementation returns
a big number if is_met() is true, otherwise 0. a big number if is_met() is true, otherwise 0.
@ -180,144 +83,210 @@ class ConditionController(object):
Either this method, or is_met() must be overridden in subclasses. Either this method, or is_met() must be overridden in subclasses.
''' '''
return 99999999 if self.is_met(user) else 0 return _BIG_QUANTITY if self.is_met(user, filtered) else 0
def is_met(self, user): def is_met(self, user, filtered=False):
''' Returns True if this flag condition is met, otherwise returns ''' Returns True if this flag condition is met, otherwise returns
False. False.
Either this method, or user_quantity_remaining() must be overridden Either this method, or user_quantity_remaining() must be overridden
in subclasses. in subclasses.
Arguments:
user (User): The user for whom this test must be met.
filter (bool): If true, this condition was part of a queryset
returned by pre_filter() for this user.
''' '''
return self.user_quantity_remaining(user) > 0 return self.user_quantity_remaining(user, filtered) > 0
class CategoryConditionController(ConditionController): class IsMetByFilter(object):
def __init__(self, condition): def is_met(self, user, filtered=False):
self.condition = condition ''' Returns True if this flag condition is met, otherwise returns
False. It determines if the condition is met by calling pre_filter
with a queryset containing only self.condition. '''
def is_met(self, user): if filtered:
''' returns True if the user has a product from a category that invokes return True # Why query again?
this condition in one of their carts '''
carts = commerce.Cart.objects.filter(user=user) return self.passes_filter(user)
carts = carts.exclude(status=commerce.Cart.STATUS_RELEASED)
enabling_products = inventory.Product.objects.filter(
category=self.condition.enabling_category, class RemainderSetByFilter(object):
def user_quantity_remaining(self, user, filtered=True):
''' returns 0 if the date range is violated, otherwise, it will return
the quantity remaining under the stock limit.
The filter for this condition must add an annotation called "remainder"
in order for this to work.
'''
if filtered:
if hasattr(self.condition, "remainder"):
return self.condition.remainder
# Mark self.condition with a remainder
qs = type(self.condition).objects.filter(pk=self.condition.id)
qs = self.pre_filter(qs, user)
if len(qs) > 0:
return qs[0].remainder
else:
return 0
class CategoryConditionController(IsMetByFilter, ConditionController):
@classmethod
def pre_filter(self, queryset, user):
''' Returns all of the items from queryset where the user has a
product from a category invoking that item's condition in one of their
carts. '''
in_user_carts = Q(
enabling_category__product__productitem__cart__user=user
) )
products_count = commerce.ProductItem.objects.filter( released = commerce.Cart.STATUS_RELEASED
cart__in=carts, in_released_carts = Q(
product__in=enabling_products, enabling_category__product__productitem__cart__status=released
).count() )
return products_count > 0 queryset = queryset.filter(in_user_carts)
queryset = queryset.exclude(in_released_carts)
return queryset
class ProductConditionController(ConditionController): class ProductConditionController(IsMetByFilter, ConditionController):
''' Condition tests for ProductFlag and ''' Condition tests for ProductFlag and
IncludedProductDiscount. ''' IncludedProductDiscount. '''
def __init__(self, condition): @classmethod
self.condition = condition def pre_filter(self, queryset, user):
''' Returns all of the items from queryset where the user has a
product invoking that item's condition in one of their carts. '''
def is_met(self, user): in_user_carts = Q(enabling_products__productitem__cart__user=user)
''' returns True if the user has a product that invokes this released = commerce.Cart.STATUS_RELEASED
condition in one of their carts ''' in_released_carts = Q(
enabling_products__productitem__cart__status=released
)
queryset = queryset.filter(in_user_carts)
queryset = queryset.exclude(in_released_carts)
carts = commerce.Cart.objects.filter(user=user) return queryset
carts = carts.exclude(status=commerce.Cart.STATUS_RELEASED)
products_count = commerce.ProductItem.objects.filter(
cart__in=carts,
product__in=self.condition.enabling_products.all(),
).count()
return products_count > 0
class TimeOrStockLimitConditionController(ConditionController): class TimeOrStockLimitConditionController(
RemainderSetByFilter,
ConditionController,
):
''' Common condition tests for TimeOrStockLimit Flag and ''' Common condition tests for TimeOrStockLimit Flag and
Discount.''' Discount.'''
def __init__(self, ceiling): @classmethod
self.ceiling = ceiling def pre_filter(self, queryset, user):
''' Returns all of the items from queryset where the date falls into
any specified range, but not yet where the stock limit is not yet
reached.'''
def user_quantity_remaining(self, user):
''' returns 0 if the date range is violated, otherwise, it will return
the quantity remaining under the stock limit. '''
# Test date range
if not self._test_date_range():
return 0
return self._get_remaining_stock(user)
def _test_date_range(self):
now = timezone.now() now = timezone.now()
if self.ceiling.start_time is not None: # Keep items with no start time, or start time not yet met.
if now < self.ceiling.start_time: queryset = queryset.filter(Q(start_time=None) | Q(start_time__lte=now))
return False queryset = queryset.filter(Q(end_time=None) | Q(end_time__gte=now))
if self.ceiling.end_time is not None: # Filter out items that have been reserved beyond the limits
if now > self.ceiling.end_time: quantity_or_zero = self._calculate_quantities(user)
return False
return True remainder = Case(
When(limit=None, then=Value(_BIG_QUANTITY)),
default=F("limit") - Sum(quantity_or_zero),
)
def _get_remaining_stock(self, user): queryset = queryset.annotate(remainder=remainder)
''' Returns the stock that remains under this ceiling, excluding the queryset = queryset.filter(remainder__gt=0)
user's current cart. '''
if self.ceiling.limit is None: return queryset
return 99999999
# We care about all reserved carts, but not the user's current cart @classmethod
def _relevant_carts(cls, user):
reserved_carts = commerce.Cart.reserved_carts() reserved_carts = commerce.Cart.reserved_carts()
reserved_carts = reserved_carts.exclude( reserved_carts = reserved_carts.exclude(
user=user, user=user,
status=commerce.Cart.STATUS_ACTIVE, status=commerce.Cart.STATUS_ACTIVE,
) )
return reserved_carts
items = self._items()
items = items.filter(cart__in=reserved_carts)
count = items.aggregate(Sum("quantity"))["quantity__sum"] or 0
return self.ceiling.limit - count
class TimeOrStockLimitFlagController( class TimeOrStockLimitFlagController(
TimeOrStockLimitConditionController): TimeOrStockLimitConditionController):
def _items(self): @classmethod
category_products = inventory.Product.objects.filter( def _calculate_quantities(cls, user):
category__in=self.ceiling.categories.all(), reserved_carts = cls._relevant_carts(user)
)
products = self.ceiling.products.all() | category_products
product_items = commerce.ProductItem.objects.filter( # Calculate category lines
product__in=products.all(), item_cats = F('categories__product__productitem__product__category')
reserved_category_products = (
Q(categories=item_cats) &
Q(categories__product__productitem__cart__in=reserved_carts)
) )
return product_items
# Calculate product lines
reserved_products = (
Q(products=F('products__productitem__product')) &
Q(products__productitem__cart__in=reserved_carts)
)
category_quantity_in_reserved_carts = When(
reserved_category_products,
then="categories__product__productitem__quantity",
)
product_quantity_in_reserved_carts = When(
reserved_products,
then="products__productitem__quantity",
)
quantity_or_zero = Case(
category_quantity_in_reserved_carts,
product_quantity_in_reserved_carts,
default=Value(0),
)
return quantity_or_zero
class TimeOrStockLimitDiscountController(TimeOrStockLimitConditionController): class TimeOrStockLimitDiscountController(TimeOrStockLimitConditionController):
def _items(self): @classmethod
discount_items = commerce.DiscountItem.objects.filter( def _calculate_quantities(cls, user):
discount=self.ceiling, reserved_carts = cls._relevant_carts(user)
quantity_in_reserved_carts = When(
discountitem__cart__in=reserved_carts,
then="discountitem__quantity"
) )
return discount_items
quantity_or_zero = Case(
quantity_in_reserved_carts,
default=Value(0)
)
return quantity_or_zero
class VoucherConditionController(ConditionController): class VoucherConditionController(IsMetByFilter, ConditionController):
''' Condition test for VoucherFlag and VoucherDiscount.''' ''' Condition test for VoucherFlag and VoucherDiscount.'''
def __init__(self, condition): @classmethod
self.condition = condition def pre_filter(self, queryset, user):
''' Returns all of the items from queryset where the user has entered
a voucher that invokes that item's condition in one of their carts. '''
def is_met(self, user): return queryset.filter(voucher__cart__user=user)
''' returns True if the user has the given voucher attached. '''
carts_count = commerce.Cart.objects.filter(
user=user,
vouchers=self.condition.voucher,
).count()
return carts_count > 0

View file

@ -4,7 +4,11 @@ from conditions import ConditionController
from registrasion.models import commerce from registrasion.models import commerce
from registrasion.models import conditions from registrasion.models import conditions
from django.db.models import Case
from django.db.models import F, Q
from django.db.models import Sum from django.db.models import Sum
from django.db.models import Value
from django.db.models import When
class DiscountAndQuantity(object): class DiscountAndQuantity(object):
@ -38,80 +42,158 @@ class DiscountAndQuantity(object):
) )
def available_discounts(user, categories, products): class DiscountController(object):
''' Returns all discounts available to this user for the given categories
and products. The discounts also list the available quantity for this user,
not including products that are pending purchase. '''
# discounts that match provided categories @classmethod
category_discounts = conditions.DiscountForCategory.objects.filter( def available_discounts(cls, user, categories, products):
category__in=categories ''' Returns all discounts available to this user for the given
) categories and products. The discounts also list the available quantity
# discounts that match provided products for this user, not including products that are pending purchase. '''
product_discounts = conditions.DiscountForProduct.objects.filter(
product__in=products
)
# discounts that match categories for provided products
product_category_discounts = conditions.DiscountForCategory.objects.filter(
category__in=(product.category for product in products)
)
# (Not relevant: discounts that match products in provided categories)
product_discounts = product_discounts.select_related( filtered_clauses = cls._filtered_discounts(user, categories, products)
"product",
"product__category",
)
all_category_discounts = category_discounts | product_category_discounts discounts = []
all_category_discounts = all_category_discounts.select_related(
"category",
)
# The set of all potential discounts # Markers so that we don't need to evaluate given conditions
potential_discounts = set(itertools.chain( # more than once
product_discounts, accepted_discounts = set()
all_category_discounts, failed_discounts = set()
))
discounts = [] for clause in filtered_clauses:
discount = clause.discount
cond = ConditionController.for_condition(discount)
# Markers so that we don't need to evaluate given conditions more than once past_use_count = clause.past_use_count
accepted_discounts = set() if past_use_count >= clause.quantity:
failed_discounts = set() # This clause has exceeded its use count
pass
elif discount not in failed_discounts:
# This clause is still available
is_accepted = discount in accepted_discounts
if is_accepted or cond.is_met(user, filtered=True):
# This clause is valid for this user
discounts.append(DiscountAndQuantity(
discount=discount,
clause=clause,
quantity=clause.quantity - past_use_count,
))
accepted_discounts.add(discount)
else:
# This clause is not valid for this user
failed_discounts.add(discount)
return discounts
for discount in potential_discounts: @classmethod
real_discount = conditions.DiscountBase.objects.get_subclass( def _filtered_discounts(cls, user, categories, products):
pk=discount.discount.pk, '''
Returns:
Sequence[discountbase]: All discounts that passed the filter
function.
'''
types = list(ConditionController._controllers())
discounttypes = [
i for i in types if issubclass(i, conditions.DiscountBase)
]
# discounts that match provided categories
category_discounts = conditions.DiscountForCategory.objects.filter(
category__in=categories
) )
cond = ConditionController.for_condition(real_discount) # discounts that match provided products
product_discounts = conditions.DiscountForProduct.objects.filter(
# Count the past uses of the given discount item. product__in=products
# If this user has exceeded the limit for the clause, this clause
# is not available any more.
past_uses = commerce.DiscountItem.objects.filter(
cart__user=user,
cart__status=commerce.Cart.STATUS_PAID, # Only past carts count
discount=real_discount,
) )
agg = past_uses.aggregate(Sum("quantity")) # discounts that match categories for provided products
past_use_count = agg["quantity__sum"] product_category_discounts = conditions.DiscountForCategory.objects
if past_use_count is None: product_category_discounts = product_category_discounts.filter(
past_use_count = 0 category__in=(product.category for product in products)
)
# (Not relevant: discounts that match products in provided categories)
if past_use_count >= discount.quantity: product_discounts = product_discounts.select_related(
# This clause has exceeded its use count "product",
pass "product__category",
elif real_discount not in failed_discounts: )
# This clause is still available
if real_discount in accepted_discounts or cond.is_met(user): all_category_discounts = (
# This clause is valid for this user category_discounts | product_category_discounts
discounts.append(DiscountAndQuantity( )
discount=real_discount, all_category_discounts = all_category_discounts.select_related(
clause=discount, "category",
quantity=discount.quantity - past_use_count, )
))
accepted_discounts.add(real_discount) valid_discounts = conditions.DiscountBase.objects.filter(
else: Q(discountforproduct__in=product_discounts) |
# This clause is not valid for this user Q(discountforcategory__in=all_category_discounts)
failed_discounts.add(real_discount) )
return discounts
all_subsets = []
for discounttype in discounttypes:
discounts = discounttype.objects.filter(id__in=valid_discounts)
ctrl = ConditionController.for_type(discounttype)
discounts = ctrl.pre_filter(discounts, user)
all_subsets.append(discounts)
filtered_discounts = list(itertools.chain(*all_subsets))
# Map from discount key to itself
# (contains annotations needed in the future)
from_filter = dict((i.id, i) for i in filtered_discounts)
clause_sets = (
product_discounts.filter(discount__in=filtered_discounts),
all_category_discounts.filter(discount__in=filtered_discounts),
)
clause_sets = (
cls._annotate_with_past_uses(i, user) for i in clause_sets
)
# The set of all potential discount clauses
discount_clauses = set(itertools.chain(*clause_sets))
# Replace discounts with the filtered ones
# These are the correct subclasses (saves query later on), and have
# correct annotations from filters if necessary.
for clause in discount_clauses:
clause.discount = from_filter[clause.discount.id]
return discount_clauses
@classmethod
def _annotate_with_past_uses(cls, queryset, user):
''' Annotates the queryset with a usage count for that discount claus
by the given user. '''
if queryset.model == conditions.DiscountForCategory:
matches = (
Q(category=F('discount__discountitem__product__category'))
)
elif queryset.model == conditions.DiscountForProduct:
matches = (
Q(product=F('discount__discountitem__product'))
)
in_carts = (
Q(discount__discountitem__cart__user=user) &
Q(discount__discountitem__cart__status=commerce.Cart.STATUS_PAID)
)
past_use_quantity = When(
in_carts & matches,
then="discount__discountitem__quantity",
)
past_use_quantity_or_zero = Case(
past_use_quantity,
default=Value(0),
)
queryset = queryset.annotate(
past_use_count=Sum(past_use_quantity_or_zero)
)
return queryset

View file

@ -0,0 +1,264 @@
import itertools
import operator
from collections import defaultdict
from collections import namedtuple
from django.db.models import Count
from django.db.models import Q
from .conditions import ConditionController
from registrasion.models import conditions
from registrasion.models import inventory
class FlagController(object):
SINGLE = True
PLURAL = False
NONE = True
SOME = False
MESSAGE = {
NONE: {
SINGLE:
"%(items)s is no longer available to you",
PLURAL:
"%(items)s are no longer available to you",
},
SOME: {
SINGLE:
"Only %(remainder)d of the following item remains: %(items)s",
PLURAL:
"Only %(remainder)d of the following items remain: %(items)s"
},
}
@classmethod
def test_flags(
cls, user, products=None, product_quantities=None):
''' Evaluates all of the flag conditions on the given products.
If `product_quantities` is supplied, the condition is only met if it
will permit the sum of the product quantities for all of the products
it covers. Otherwise, it will be met if at least one item can be
accepted.
If all flag conditions pass, an empty list is returned, otherwise
a list is returned containing all of the products that are *not
enabled*. '''
print "GREPME: test_flags()"
if products is not None and product_quantities is not None:
raise ValueError("Please specify only products or "
"product_quantities")
elif products is None:
products = set(i[0] for i in product_quantities)
quantities = dict((product, quantity)
for product, quantity in product_quantities)
elif product_quantities is None:
products = set(products)
quantities = {}
if products:
# Simplify the query.
all_conditions = cls._filtered_flags(user, products)
else:
all_conditions = []
# All disable-if-false conditions on a product need to be met
do_not_disable = defaultdict(lambda: True)
# At least one enable-if-true condition on a product must be met
do_enable = defaultdict(lambda: False)
# (if either sort of condition is present)
# Count the number of conditions for a product
dif_count = defaultdict(int)
eit_count = defaultdict(int)
messages = {}
for condition in all_conditions:
cond = ConditionController.for_condition(condition)
remainder = cond.user_quantity_remaining(user, filtered=True)
# Get all products covered by this condition, and the products
# from the categories covered by this condition
ids = [product.id for product in products]
all_products = inventory.Product.objects.filter(id__in=ids)
cond = (
Q(flagbase_set=condition) |
Q(category__in=condition.categories.all())
)
all_products = all_products.filter(cond)
all_products = all_products.select_related("category")
if quantities:
consumed = sum(quantities[i] for i in all_products)
else:
consumed = 1
met = consumed <= remainder
if not met:
items = ", ".join(str(product) for product in all_products)
base = cls.MESSAGE[remainder == 0][len(all_products) == 1]
message = base % {"items": items, "remainder": remainder}
for product in all_products:
if condition.is_disable_if_false:
do_not_disable[product] &= met
dif_count[product] += 1
else:
do_enable[product] |= met
eit_count[product] += 1
if not met and product not in messages:
messages[product] = message
total_flags = FlagCounter.count()
valid = {}
# the problem is that now, not every condition falls into
# do_not_disable or do_enable '''
# You should look into this, chris :)
for product in products:
if quantities:
if quantities[product] == 0:
continue
f = total_flags.get(product)
if f.dif > 0 and f.dif != dif_count[product]:
do_not_disable[product] = False
if product not in messages:
messages[product] = "Some disable-if-false " \
"conditions were not met"
if f.eit > 0 and product not in do_enable:
do_enable[product] = False
if product not in messages:
messages[product] = "Some enable-if-true " \
"conditions were not met"
for product in itertools.chain(do_not_disable, do_enable):
f = total_flags.get(product)
if product in do_enable:
# If there's an enable-if-true, we need need of those met too.
# (do_not_disable will default to true otherwise)
valid[product] = do_not_disable[product] and do_enable[product]
elif product in do_not_disable:
# If there's a disable-if-false condition, all must be met
valid[product] = do_not_disable[product]
error_fields = [
(product, messages[product])
for product in valid if not valid[product]
]
return error_fields
@classmethod
def _filtered_flags(cls, user, products):
'''
Returns:
Sequence[flagbase]: All flags that passed the filter function.
'''
types = list(ConditionController._controllers())
flagtypes = [i for i in types if issubclass(i, conditions.FlagBase)]
# Get all flags for the products and categories.
prods = (
product.flagbase_set.all()
for product in products
)
cats = (
category.flagbase_set.all()
for category in set(product.category for product in products)
)
all_flags = reduce(operator.or_, itertools.chain(prods, cats))
all_subsets = []
for flagtype in flagtypes:
flags = flagtype.objects.filter(id__in=all_flags)
ctrl = ConditionController.for_type(flagtype)
flags = ctrl.pre_filter(flags, user)
all_subsets.append(flags)
return itertools.chain(*all_subsets)
ConditionAndRemainder = namedtuple(
"ConditionAndRemainder",
(
"condition",
"remainder",
),
)
_FlagCounter = namedtuple(
"_FlagCounter",
(
"products",
"categories",
),
)
_ConditionsCount = namedtuple(
"ConditionsCount",
(
"dif",
"eit",
),
)
# TODO: this should be cacheable.
class FlagCounter(_FlagCounter):
@classmethod
def count(cls):
# Get the count of how many conditions should exist per product
flagbases = conditions.FlagBase.objects
types = (
conditions.FlagBase.ENABLE_IF_TRUE,
conditions.FlagBase.DISABLE_IF_FALSE,
)
keys = ("eit", "dif")
flags = [
flagbases.filter(
condition=condition_type
).values(
'products', 'categories'
).annotate(
count=Count('id')
)
for condition_type in types
]
cats = defaultdict(lambda: defaultdict(int))
prod = defaultdict(lambda: defaultdict(int))
for key, flagcounts in zip(keys, flags):
for row in flagcounts:
if row["products"] is not None:
prod[row["products"]][key] = row["count"]
if row["categories"] is not None:
cats[row["categories"]][key] = row["count"]
return cls(products=prod, categories=cats)
def get(self, product):
p = self.products[product.id]
c = self.categories[product.category.id]
eit = p["eit"] + c["eit"]
dif = p["dif"] + c["dif"]
return _ConditionsCount(dif=dif, eit=eit)

View file

@ -29,6 +29,7 @@ class InvoiceController(ForId, object):
If such an invoice does not exist, the cart is validated, and if valid, If such an invoice does not exist, the cart is validated, and if valid,
an invoice is generated.''' an invoice is generated.'''
cart.refresh_from_db()
try: try:
invoice = commerce.Invoice.objects.exclude( invoice = commerce.Invoice.objects.exclude(
status=commerce.Invoice.STATUS_VOID, status=commerce.Invoice.STATUS_VOID,
@ -74,6 +75,8 @@ class InvoiceController(ForId, object):
def _generate(cls, cart): def _generate(cls, cart):
''' Generates an invoice for the given cart. ''' ''' Generates an invoice for the given cart. '''
cart.refresh_from_db()
issued = timezone.now() issued = timezone.now()
reservation_limit = cart.reservation_duration + cart.time_last_updated reservation_limit = cart.reservation_duration + cart.time_last_updated
# Never generate a due time that is before the issue time # Never generate a due time that is before the issue time
@ -96,6 +99,10 @@ class InvoiceController(ForId, object):
) )
product_items = commerce.ProductItem.objects.filter(cart=cart) product_items = commerce.ProductItem.objects.filter(cart=cart)
product_items = product_items.select_related(
"product",
"product__category",
)
if len(product_items) == 0: if len(product_items) == 0:
raise ValidationError("Your cart is empty.") raise ValidationError("Your cart is empty.")
@ -103,29 +110,41 @@ class InvoiceController(ForId, object):
product_items = product_items.order_by( product_items = product_items.order_by(
"product__category__order", "product__order" "product__category__order", "product__order"
) )
discount_items = commerce.DiscountItem.objects.filter(cart=cart) discount_items = commerce.DiscountItem.objects.filter(cart=cart)
discount_items = discount_items.select_related(
"discount",
"product",
"product__category",
)
line_items = []
invoice_value = Decimal() invoice_value = Decimal()
for item in product_items: for item in product_items:
product = item.product product = item.product
line_item = commerce.LineItem.objects.create( line_item = commerce.LineItem(
invoice=invoice, invoice=invoice,
description="%s - %s" % (product.category.name, product.name), description="%s - %s" % (product.category.name, product.name),
quantity=item.quantity, quantity=item.quantity,
price=product.price, price=product.price,
product=product, product=product,
) )
line_items.append(line_item)
invoice_value += line_item.quantity * line_item.price invoice_value += line_item.quantity * line_item.price
for item in discount_items: for item in discount_items:
line_item = commerce.LineItem.objects.create( line_item = commerce.LineItem(
invoice=invoice, invoice=invoice,
description=item.discount.description, description=item.discount.description,
quantity=item.quantity, quantity=item.quantity,
price=cls.resolve_discount_value(item) * -1, price=cls.resolve_discount_value(item) * -1,
product=item.product, product=item.product,
) )
line_items.append(line_item)
invoice_value += line_item.quantity * line_item.price invoice_value += line_item.quantity * line_item.price
commerce.LineItem.objects.bulk_create(line_items)
invoice.value = invoice_value invoice.value = invoice_value
invoice.save() invoice.save()
@ -251,6 +270,9 @@ class InvoiceController(ForId, object):
def _invoice_matches_cart(self): def _invoice_matches_cart(self):
''' Returns true if there is no cart, or if the revision of this ''' Returns true if there is no cart, or if the revision of this
invoice matches the current revision of the cart. ''' invoice matches the current revision of the cart. '''
self._refresh()
cart = self.invoice.cart cart = self.invoice.cart
if not cart: if not cart:
return True return True

View file

@ -1,11 +1,16 @@
import itertools import itertools
from django.db.models import Case
from django.db.models import F, Q
from django.db.models import Sum from django.db.models import Sum
from django.db.models import When
from django.db.models import Value
from registrasion.models import commerce from registrasion.models import commerce
from registrasion.models import inventory from registrasion.models import inventory
from category import CategoryController from .category import CategoryController
from conditions import ConditionController from .flag import FlagController
class ProductController(object): class ProductController(object):
@ -16,9 +21,7 @@ class ProductController(object):
@classmethod @classmethod
def available_products(cls, user, category=None, products=None): def available_products(cls, user, category=None, products=None):
''' Returns a list of all of the products that are available per ''' Returns a list of all of the products that are available per
flag conditions from the given categories. flag conditions from the given categories. '''
TODO: refactor so that all conditions are tested here and
can_add_with_flags calls this method. '''
if category is None and products is None: if category is None and products is None:
raise ValueError("You must provide products or a category") raise ValueError("You must provide products or a category")
@ -31,22 +34,21 @@ class ProductController(object):
if products is not None: if products is not None:
all_products = set(itertools.chain(all_products, products)) all_products = set(itertools.chain(all_products, products))
cat_quants = dict( categories = set(product.category for product in all_products)
( r = CategoryController.attach_user_remainders(user, categories)
category, cat_quants = dict((c, c) for c in r)
CategoryController(category).user_quantity_remaining(user),
) r = ProductController.attach_user_remainders(user, all_products)
for category in set(product.category for product in all_products) prod_quants = dict((p, p) for p in r)
)
passed_limits = set( passed_limits = set(
product product
for product in all_products for product in all_products
if cat_quants[product.category] > 0 if cat_quants[product.category].remainder > 0
if cls(product).user_quantity_remaining(user) > 0 if prod_quants[product].remainder > 0
) )
failed_and_messages = ConditionController.test_flags( failed_and_messages = FlagController.test_flags(
user, products=passed_limits user, products=passed_limits
) )
failed_conditions = set(i[0] for i in failed_and_messages) failed_conditions = set(i[0] for i in failed_and_messages)
@ -56,26 +58,47 @@ class ProductController(object):
return out return out
@classmethod
def attach_user_remainders(cls, user, products):
'''
Return:
queryset(inventory.Product): A queryset containing items from
``product``, with an extra attribute -- remainder = the amount of
this item that is remaining.
'''
ids = [product.id for product in products]
products = inventory.Product.objects.filter(id__in=ids)
cart_filter = (
Q(productitem__cart__user=user) &
Q(productitem__cart__status=commerce.Cart.STATUS_PAID)
)
quantity = When(
cart_filter,
then='productitem__quantity'
)
quantity_or_zero = Case(
quantity,
default=Value(0),
)
remainder = Case(
When(limit_per_user=None, then=Value(99999999)),
default=F('limit_per_user') - Sum(quantity_or_zero),
)
products = products.annotate(remainder=remainder)
return products
def user_quantity_remaining(self, user): def user_quantity_remaining(self, user):
''' Returns the quantity of this product that the user add in the ''' Returns the quantity of this product that the user add in the
current cart. ''' current cart. '''
prod_limit = self.product.limit_per_user with_remainders = self.attach_user_remainders(user, [self.product])
if prod_limit is None: return with_remainders[0].remainder
# Don't need to run the remaining queries
return 999999 # We can do better
carts = commerce.Cart.objects.filter(
user=user,
status=commerce.Cart.STATUS_PAID,
)
items = commerce.ProductItem.objects.filter(
cart__in=carts,
product=self.product,
)
prod_count = items.aggregate(Sum("quantity"))["quantity__sum"] or 0
return prod_limit - prod_count

View file

@ -4,7 +4,11 @@ from registrasion.controllers.category import CategoryController
from collections import namedtuple from collections import namedtuple
from django import template from django import template
from django.db.models import Case
from django.db.models import Q
from django.db.models import Sum from django.db.models import Sum
from django.db.models import When
from django.db.models import Value
register = template.Library() register = template.Library()
@ -99,20 +103,33 @@ def items_purchased(context, category=None):
''' '''
all_items = commerce.ProductItem.objects.filter( in_cart = (
cart__user=context.request.user, Q(productitem__cart__user=context.request.user) &
cart__status=commerce.Cart.STATUS_PAID, Q(productitem__cart__status=commerce.Cart.STATUS_PAID)
).select_related("product", "product__category") )
quantities_in_cart = When(
in_cart,
then="productitem__quantity",
)
quantities_or_zero = Case(
quantities_in_cart,
default=Value(0),
)
products = inventory.Product.objects
if category: if category:
all_items = all_items.filter(product__category=category) products = products.filter(category=category)
products = products.select_related("category")
products = products.annotate(quantity=Sum(quantities_or_zero))
products = products.filter(quantity__gt=0)
pq = all_items.values("product").annotate(quantity=Sum("quantity")).all()
products = inventory.Product.objects.all()
out = [] out = []
for item in pq: for prod in products:
prod = products.get(pk=item["product"]) out.append(ProductAndQuantity(prod, prod.quantity))
out.append(ProductAndQuantity(prod, item["quantity"]))
return out return out

View file

@ -26,7 +26,7 @@ class RegistrationCartTestCase(SetTimeMixin, TestCase):
super(RegistrationCartTestCase, self).setUp() super(RegistrationCartTestCase, self).setUp()
def tearDown(self): def tearDown(self):
if False: if True:
# If you're seeing segfaults in tests, enable this. # If you're seeing segfaults in tests, enable this.
call_command( call_command(
'flush', 'flush',

View file

@ -6,6 +6,8 @@ from django.core.exceptions import ValidationError
from controller_helpers import TestingCartController from controller_helpers import TestingCartController
from test_cart import RegistrationCartTestCase from test_cart import RegistrationCartTestCase
from registrasion.controllers.discount import DiscountController
from registrasion.controllers.product import ProductController
from registrasion.models import commerce from registrasion.models import commerce
from registrasion.models import conditions from registrasion.models import conditions
@ -135,6 +137,43 @@ class CeilingsTestCases(RegistrationCartTestCase):
with self.assertRaises(ValidationError): with self.assertRaises(ValidationError):
first_cart.validate_cart() first_cart.validate_cart()
def test_discount_ceiling_aggregates_products(self):
# Create two carts, add 1xprod_1 to each. Ceiling should disappear
# after second.
self.make_discount_ceiling(
"Multi-product limit discount ceiling",
limit=2,
)
for i in xrange(2):
cart = TestingCartController.for_user(self.USER_1)
cart.add_to_cart(self.PROD_1, 1)
cart.next_cart()
discounts = DiscountController.available_discounts(
self.USER_1,
[],
[self.PROD_1],
)
self.assertEqual(0, len(discounts))
def test_flag_ceiling_aggregates_products(self):
# Create two carts, add 1xprod_1 to each. Ceiling should disappear
# after second.
self.make_ceiling("Multi-product limit ceiling", limit=2)
for i in xrange(2):
cart = TestingCartController.for_user(self.USER_1)
cart.add_to_cart(self.PROD_1, 1)
cart.next_cart()
products = ProductController.available_products(
self.USER_1,
products=[self.PROD_1],
)
self.assertEqual(0, len(products))
def test_items_released_from_ceiling_by_refund(self): def test_items_released_from_ceiling_by_refund(self):
self.make_ceiling("Limit ceiling", limit=1) self.make_ceiling("Limit ceiling", limit=1)

View file

@ -4,7 +4,7 @@ from decimal import Decimal
from registrasion.models import commerce from registrasion.models import commerce
from registrasion.models import conditions from registrasion.models import conditions
from registrasion.controllers import discount from registrasion.controllers.discount import DiscountController
from controller_helpers import TestingCartController from controller_helpers import TestingCartController
from test_cart import RegistrationCartTestCase from test_cart import RegistrationCartTestCase
@ -243,22 +243,30 @@ class DiscountTestCase(RegistrationCartTestCase):
# The discount is applied. # The discount is applied.
self.assertEqual(1, len(discount_items)) self.assertEqual(1, len(discount_items))
# Tests for the discount.available_discounts enumerator # Tests for the DiscountController.available_discounts enumerator
def test_enumerate_no_discounts_for_no_input(self): def test_enumerate_no_discounts_for_no_input(self):
discounts = discount.available_discounts(self.USER_1, [], []) discounts = DiscountController.available_discounts(
self.USER_1,
[],
[],
)
self.assertEqual(0, len(discounts)) self.assertEqual(0, len(discounts))
def test_enumerate_no_discounts_if_condition_not_met(self): def test_enumerate_no_discounts_if_condition_not_met(self):
self.add_discount_prod_1_includes_cat_2(quantity=1) self.add_discount_prod_1_includes_cat_2(quantity=1)
discounts = discount.available_discounts( discounts = DiscountController.available_discounts(
self.USER_1, self.USER_1,
[], [],
[self.PROD_3], [self.PROD_3],
) )
self.assertEqual(0, len(discounts)) self.assertEqual(0, len(discounts))
discounts = discount.available_discounts(self.USER_1, [self.CAT_2], []) discounts = DiscountController.available_discounts(
self.USER_1,
[self.CAT_2],
[],
)
self.assertEqual(0, len(discounts)) self.assertEqual(0, len(discounts))
def test_category_discount_appears_once_if_met_twice(self): def test_category_discount_appears_once_if_met_twice(self):
@ -267,7 +275,7 @@ class DiscountTestCase(RegistrationCartTestCase):
cart = TestingCartController.for_user(self.USER_1) cart = TestingCartController.for_user(self.USER_1)
cart.add_to_cart(self.PROD_1, 1) # Enable the discount cart.add_to_cart(self.PROD_1, 1) # Enable the discount
discounts = discount.available_discounts( discounts = DiscountController.available_discounts(
self.USER_1, self.USER_1,
[self.CAT_2], [self.CAT_2],
[self.PROD_3], [self.PROD_3],
@ -280,7 +288,11 @@ class DiscountTestCase(RegistrationCartTestCase):
cart = TestingCartController.for_user(self.USER_1) cart = TestingCartController.for_user(self.USER_1)
cart.add_to_cart(self.PROD_1, 1) # Enable the discount cart.add_to_cart(self.PROD_1, 1) # Enable the discount
discounts = discount.available_discounts(self.USER_1, [self.CAT_2], []) discounts = DiscountController.available_discounts(
self.USER_1,
[self.CAT_2],
[],
)
self.assertEqual(1, len(discounts)) self.assertEqual(1, len(discounts))
def test_category_discount_appears_with_product(self): def test_category_discount_appears_with_product(self):
@ -289,7 +301,7 @@ class DiscountTestCase(RegistrationCartTestCase):
cart = TestingCartController.for_user(self.USER_1) cart = TestingCartController.for_user(self.USER_1)
cart.add_to_cart(self.PROD_1, 1) # Enable the discount cart.add_to_cart(self.PROD_1, 1) # Enable the discount
discounts = discount.available_discounts( discounts = DiscountController.available_discounts(
self.USER_1, self.USER_1,
[], [],
[self.PROD_3], [self.PROD_3],
@ -302,7 +314,7 @@ class DiscountTestCase(RegistrationCartTestCase):
cart = TestingCartController.for_user(self.USER_1) cart = TestingCartController.for_user(self.USER_1)
cart.add_to_cart(self.PROD_1, 1) # Enable the discount cart.add_to_cart(self.PROD_1, 1) # Enable the discount
discounts = discount.available_discounts( discounts = DiscountController.available_discounts(
self.USER_1, self.USER_1,
[], [],
[self.PROD_3, self.PROD_4] [self.PROD_3, self.PROD_4]
@ -315,7 +327,7 @@ class DiscountTestCase(RegistrationCartTestCase):
cart = TestingCartController.for_user(self.USER_1) cart = TestingCartController.for_user(self.USER_1)
cart.add_to_cart(self.PROD_1, 1) # Enable the discount cart.add_to_cart(self.PROD_1, 1) # Enable the discount
discounts = discount.available_discounts( discounts = DiscountController.available_discounts(
self.USER_1, self.USER_1,
[], [],
[self.PROD_2], [self.PROD_2],
@ -328,7 +340,11 @@ class DiscountTestCase(RegistrationCartTestCase):
cart = TestingCartController.for_user(self.USER_1) cart = TestingCartController.for_user(self.USER_1)
cart.add_to_cart(self.PROD_1, 1) # Enable the discount cart.add_to_cart(self.PROD_1, 1) # Enable the discount
discounts = discount.available_discounts(self.USER_1, [self.CAT_1], []) discounts = DiscountController.available_discounts(
self.USER_1,
[self.CAT_1],
[],
)
self.assertEqual(0, len(discounts)) self.assertEqual(0, len(discounts))
def test_discount_quantity_is_correct_before_first_purchase(self): def test_discount_quantity_is_correct_before_first_purchase(self):
@ -338,7 +354,11 @@ class DiscountTestCase(RegistrationCartTestCase):
cart.add_to_cart(self.PROD_1, 1) # Enable the discount cart.add_to_cart(self.PROD_1, 1) # Enable the discount
cart.add_to_cart(self.PROD_3, 1) # Exhaust the quantity cart.add_to_cart(self.PROD_3, 1) # Exhaust the quantity
discounts = discount.available_discounts(self.USER_1, [self.CAT_2], []) discounts = DiscountController.available_discounts(
self.USER_1,
[self.CAT_2],
[],
)
self.assertEqual(2, discounts[0].quantity) self.assertEqual(2, discounts[0].quantity)
cart.next_cart() cart.next_cart()
@ -349,32 +369,63 @@ class DiscountTestCase(RegistrationCartTestCase):
cart = TestingCartController.for_user(self.USER_1) cart = TestingCartController.for_user(self.USER_1)
cart.add_to_cart(self.PROD_3, 1) # Exhaust the quantity cart.add_to_cart(self.PROD_3, 1) # Exhaust the quantity
discounts = discount.available_discounts(self.USER_1, [self.CAT_2], []) discounts = DiscountController.available_discounts(
self.USER_1,
[self.CAT_2],
[],
)
self.assertEqual(1, discounts[0].quantity) self.assertEqual(1, discounts[0].quantity)
cart.next_cart() cart.next_cart()
def test_discount_is_gone_after_quantity_exhausted(self): def test_discount_is_gone_after_quantity_exhausted(self):
self.test_discount_quantity_is_correct_after_first_purchase() self.test_discount_quantity_is_correct_after_first_purchase()
discounts = discount.available_discounts(self.USER_1, [self.CAT_2], []) discounts = DiscountController.available_discounts(
self.USER_1,
[self.CAT_2],
[],
)
self.assertEqual(0, len(discounts)) self.assertEqual(0, len(discounts))
def test_product_discount_enabled_twice_appears_twice(self): def test_product_discount_enabled_twice_appears_twice(self):
self.add_discount_prod_1_includes_prod_3_and_prod_4(quantity=2) self.add_discount_prod_1_includes_prod_3_and_prod_4(quantity=2)
cart = TestingCartController.for_user(self.USER_1) cart = TestingCartController.for_user(self.USER_1)
cart.add_to_cart(self.PROD_1, 1) # Enable the discount cart.add_to_cart(self.PROD_1, 1) # Enable the discount
discounts = discount.available_discounts( discounts = DiscountController.available_discounts(
self.USER_1, self.USER_1,
[], [],
[self.PROD_3, self.PROD_4], [self.PROD_3, self.PROD_4],
) )
self.assertEqual(2, len(discounts)) self.assertEqual(2, len(discounts))
def test_product_discount_applied_on_different_invoices(self):
# quantity=1 means "quantity per product"
self.add_discount_prod_1_includes_prod_3_and_prod_4(quantity=1)
cart = TestingCartController.for_user(self.USER_1)
cart.add_to_cart(self.PROD_1, 1) # Enable the discount
discounts = DiscountController.available_discounts(
self.USER_1,
[],
[self.PROD_3, self.PROD_4],
)
self.assertEqual(2, len(discounts))
# adding one of PROD_3 should make it no longer an available discount.
cart.add_to_cart(self.PROD_3, 1)
cart.next_cart()
# should still have (and only have) the discount for prod_4
discounts = DiscountController.available_discounts(
self.USER_1,
[],
[self.PROD_3, self.PROD_4],
)
self.assertEqual(1, len(discounts))
def test_discounts_are_released_by_refunds(self): def test_discounts_are_released_by_refunds(self):
self.add_discount_prod_1_includes_prod_2(quantity=2) self.add_discount_prod_1_includes_prod_2(quantity=2)
cart = TestingCartController.for_user(self.USER_1) cart = TestingCartController.for_user(self.USER_1)
cart.add_to_cart(self.PROD_1, 1) # Enable the discount cart.add_to_cart(self.PROD_1, 1) # Enable the discount
discounts = discount.available_discounts( discounts = DiscountController.available_discounts(
self.USER_1, self.USER_1,
[], [],
[self.PROD_2], [self.PROD_2],
@ -388,7 +439,7 @@ class DiscountTestCase(RegistrationCartTestCase):
cart.next_cart() cart.next_cart()
discounts = discount.available_discounts( discounts = DiscountController.available_discounts(
self.USER_1, self.USER_1,
[], [],
[self.PROD_2], [self.PROD_2],
@ -398,7 +449,7 @@ class DiscountTestCase(RegistrationCartTestCase):
cart.cart.status = commerce.Cart.STATUS_RELEASED cart.cart.status = commerce.Cart.STATUS_RELEASED
cart.cart.save() cart.cart.save()
discounts = discount.available_discounts( discounts = DiscountController.available_discounts(
self.USER_1, self.USER_1,
[], [],
[self.PROD_2], [self.PROD_2],

View file

@ -25,3 +25,33 @@ def all_arguments_optional(ntcls):
) )
return ntcls return ntcls
def lazy(function, *args, **kwargs):
''' Produces a callable so that functions can be lazily evaluated in
templates.
Arguments:
function (callable): The function to call at evaluation time.
args: Positional arguments, passed directly to ``function``.
kwargs: Keyword arguments, passed directly to ``function``.
Return:
callable: A callable that will evaluate a call to ``function`` with
the specified arguments.
'''
NOT_EVALUATED = object()
retval = [NOT_EVALUATED]
def evaluate():
if retval[0] is NOT_EVALUATED:
retval[0] = function(*args, **kwargs)
return retval[0]
return evaluate

View file

@ -5,7 +5,7 @@ from registrasion import util
from registrasion.models import commerce from registrasion.models import commerce
from registrasion.models import inventory from registrasion.models import inventory
from registrasion.models import people from registrasion.models import people
from registrasion.controllers import discount from registrasion.controllers.discount import DiscountController
from registrasion.controllers.cart import CartController from registrasion.controllers.cart import CartController
from registrasion.controllers.credit_note import CreditNoteController from registrasion.controllers.credit_note import CreditNoteController
from registrasion.controllers.invoice import InvoiceController from registrasion.controllers.invoice import InvoiceController
@ -181,33 +181,35 @@ def guided_registration(request):
attendee.save() attendee.save()
return next_step return next_step
for category in cats: with CartController.operations_batch(request.user):
products = [ for category in cats:
i for i in available_products products = [
if i.category == category i for i in available_products
] if i.category == category
]
prefix = "category_" + str(category.id) prefix = "category_" + str(category.id)
p = _handle_products(request, category, products, prefix) p = _handle_products(request, category, products, prefix)
products_form, discounts, products_handled = p products_form, discounts, products_handled = p
section = GuidedRegistrationSection( section = GuidedRegistrationSection(
title=category.name, title=category.name,
description=category.description, description=category.description,
discounts=discounts, discounts=discounts,
form=products_form, form=products_form,
) )
if products: if products:
# This product category has items to show. # This product category has items to show.
sections.append(section) sections.append(section)
# Add this to the list of things to show if the form errors. # Add this to the list of things to show if the form
request.session[SESSION_KEY].append(category.id) # errors.
request.session[SESSION_KEY].append(category.id)
if request.method == "POST" and not products_form.errors: if request.method == "POST" and not products_form.errors:
# This is only saved if we pass each form with no errors, # This is only saved if we pass each form with no
# and if the form actually has products. # errors, and if the form actually has products.
attendee.guided_categories_complete.add(category) attendee.guided_categories_complete.add(category)
if sections and request.method == "POST": if sections and request.method == "POST":
for section in sections: for section in sections:
@ -427,7 +429,15 @@ def _handle_products(request, category, products, prefix):
) )
handled = False if products_form.errors else True handled = False if products_form.errors else True
discounts = discount.available_discounts(request.user, [], products) # Making this a function to lazily evaluate when it's displayed
# in templates.
discounts = util.lazy(
DiscountController.available_discounts,
request.user,
[],
products,
)
return products_form, discounts, handled return products_form, discounts, handled
@ -435,14 +445,14 @@ def _handle_products(request, category, products, prefix):
def _set_quantities_from_products_form(products_form, current_cart): def _set_quantities_from_products_form(products_form, current_cart):
quantities = list(products_form.product_quantities()) quantities = list(products_form.product_quantities())
id_to_quantity = dict(i[:2] for i in quantities)
pks = [i[0] for i in quantities] pks = [i[0] for i in quantities]
products = inventory.Product.objects.filter( products = inventory.Product.objects.filter(
id__in=pks, id__in=pks,
).select_related("category") ).select_related("category")
product_quantities = [ product_quantities = [
(products.get(pk=i[0]), i[1]) for i in quantities (product, id_to_quantity[product.id]) for product in products
] ]
field_names = dict( field_names = dict(
(i[0][0], i[1][2]) for i in zip(product_quantities, quantities) (i[0][0], i[1][2]) for i in zip(product_quantities, quantities)

View file

@ -1,2 +1,2 @@
[flake8] [flake8]
exclude = registrasion/migrations/*, build/*, docs/* exclude = registrasion/migrations/*, build/*, docs/*, dist/*