Rearchitected condition processing such that multiple conditions are processed by the database, in bulk. Closes #42.
This commit is contained in:
parent
05269c93cd
commit
3f1be0e14e
5 changed files with 471 additions and 137 deletions
|
@ -307,6 +307,8 @@ class CartController(object):
|
|||
self._append_errors(errors, ve)
|
||||
|
||||
# Validate the discounts
|
||||
# TODO: refactor in terms of available_discounts
|
||||
# why aren't we doing that here?!
|
||||
discount_items = commerce.DiscountItem.objects.filter(cart=cart)
|
||||
seen_discounts = set()
|
||||
|
||||
|
|
|
@ -4,7 +4,12 @@ import operator
|
|||
from collections import defaultdict
|
||||
from collections import namedtuple
|
||||
|
||||
from django.db.models import Case
|
||||
from django.db.models import Count
|
||||
from django.db.models import F, Q
|
||||
from django.db.models import Sum
|
||||
from django.db.models import Value
|
||||
from django.db.models import When
|
||||
from django.utils import timezone
|
||||
|
||||
from registrasion.models import commerce
|
||||
|
@ -12,6 +17,7 @@ from registrasion.models import conditions
|
|||
from registrasion.models import inventory
|
||||
|
||||
|
||||
|
||||
ConditionAndRemainder = namedtuple(
|
||||
"ConditionAndRemainder",
|
||||
(
|
||||
|
@ -21,16 +27,77 @@ ConditionAndRemainder = namedtuple(
|
|||
)
|
||||
|
||||
|
||||
_FlagCounter = namedtuple(
|
||||
"_FlagCounter",
|
||||
(
|
||||
"products",
|
||||
"categories",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
_ConditionsCount = namedtuple(
|
||||
"ConditionsCount",
|
||||
(
|
||||
"dif",
|
||||
"eit",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class FlagCounter(_FlagCounter):
|
||||
|
||||
@classmethod
|
||||
def count(cls):
|
||||
# Get the count of how many conditions should exist per product
|
||||
flagbases = conditions.FlagBase.objects
|
||||
|
||||
types = (
|
||||
conditions.FlagBase.ENABLE_IF_TRUE,
|
||||
conditions.FlagBase.DISABLE_IF_FALSE,
|
||||
)
|
||||
keys = ("eit", "dif")
|
||||
flags = [
|
||||
flagbases.filter(
|
||||
condition=condition_type
|
||||
).values(
|
||||
'products', 'categories'
|
||||
).annotate(
|
||||
count=Count('id')
|
||||
)
|
||||
for condition_type in types
|
||||
]
|
||||
|
||||
cats = defaultdict(lambda: defaultdict(int))
|
||||
prod = defaultdict(lambda: defaultdict(int))
|
||||
|
||||
for key, flagcounts in zip(keys, flags):
|
||||
for row in flagcounts:
|
||||
if row["products"] is not None:
|
||||
prod[row["products"]][key] = row["count"]
|
||||
if row["categories"] is not None:
|
||||
cats[row["categories"]][key] = row["count"]
|
||||
|
||||
return cls(products=prod, categories=cats)
|
||||
|
||||
def get(self, product):
|
||||
p = self.products[product.id]
|
||||
c = self.categories[product.category.id]
|
||||
eit = p["eit"] + c["eit"]
|
||||
dif = p["dif"] + c["dif"]
|
||||
return _ConditionsCount(dif=dif, eit=eit)
|
||||
|
||||
|
||||
class ConditionController(object):
|
||||
''' Base class for testing conditions that activate Flag
|
||||
or Discount objects. '''
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
def __init__(self, condition):
|
||||
self.condition = condition
|
||||
|
||||
@staticmethod
|
||||
def for_condition(condition):
|
||||
CONTROLLERS = {
|
||||
def _controllers():
|
||||
return {
|
||||
conditions.CategoryFlag: CategoryConditionController,
|
||||
conditions.IncludedProductDiscount: ProductConditionController,
|
||||
conditions.ProductFlag: ProductConditionController,
|
||||
|
@ -42,8 +109,14 @@ class ConditionController(object):
|
|||
conditions.VoucherFlag: VoucherConditionController,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def for_type(cls):
|
||||
return ConditionController._controllers()[cls]
|
||||
|
||||
@staticmethod
|
||||
def for_condition(condition):
|
||||
try:
|
||||
return CONTROLLERS[type(condition)](condition)
|
||||
return ConditionController.for_type(type(condition))(condition)
|
||||
except KeyError:
|
||||
return ConditionController()
|
||||
|
||||
|
@ -91,20 +164,9 @@ class ConditionController(object):
|
|||
products = set(products)
|
||||
quantities = {}
|
||||
|
||||
# Get the conditions covered by the products themselves
|
||||
prods = (
|
||||
product.flagbase_set.select_subclasses()
|
||||
for product in products
|
||||
)
|
||||
# Get the conditions covered by their categories
|
||||
cats = (
|
||||
category.flagbase_set.select_subclasses()
|
||||
for category in set(product.category for product in products)
|
||||
)
|
||||
|
||||
if products:
|
||||
# Simplify the query.
|
||||
all_conditions = reduce(operator.or_, itertools.chain(prods, cats))
|
||||
all_conditions = cls._filtered_flags(user, products)
|
||||
else:
|
||||
all_conditions = []
|
||||
|
||||
|
@ -114,11 +176,15 @@ class ConditionController(object):
|
|||
do_enable = defaultdict(lambda: False)
|
||||
# (if either sort of condition is present)
|
||||
|
||||
# Count the number of conditions for a product
|
||||
dif_count = defaultdict(int)
|
||||
eit_count = defaultdict(int)
|
||||
|
||||
messages = {}
|
||||
|
||||
for condition in all_conditions:
|
||||
cond = cls.for_condition(condition)
|
||||
remainder = cond.user_quantity_remaining(user)
|
||||
remainder = cond.user_quantity_remaining(user, filtered=True)
|
||||
|
||||
# Get all products covered by this condition, and the products
|
||||
# from the categories covered by this condition
|
||||
|
@ -149,14 +215,41 @@ class ConditionController(object):
|
|||
for product in all_products:
|
||||
if condition.is_disable_if_false:
|
||||
do_not_disable[product] &= met
|
||||
dif_count[product] += 1
|
||||
else:
|
||||
do_enable[product] |= met
|
||||
eit_count[product] += 1
|
||||
|
||||
if not met and product not in messages:
|
||||
messages[product] = message
|
||||
|
||||
total_flags = FlagCounter.count()
|
||||
|
||||
valid = {}
|
||||
|
||||
# the problem is that now, not every condition falls into
|
||||
# do_not_disable or do_enable '''
|
||||
# You should look into this, chris :)
|
||||
|
||||
for product in products:
|
||||
if quantities:
|
||||
if quantities[product] == 0:
|
||||
continue
|
||||
|
||||
f = total_flags.get(product)
|
||||
if f.dif > 0 and f.dif != dif_count[product]:
|
||||
do_not_disable[product] = False
|
||||
if product not in messages:
|
||||
messages[product] = "Some disable-if-false " \
|
||||
"conditions were not met"
|
||||
if f.eit > 0 and product not in do_enable:
|
||||
do_enable[product] = False
|
||||
if product not in messages:
|
||||
messages[product] = "Some enable-if-true " \
|
||||
"conditions were not met"
|
||||
|
||||
for product in itertools.chain(do_not_disable, do_enable):
|
||||
f = total_flags.get(product)
|
||||
if product in do_enable:
|
||||
# If there's an enable-if-true, we need need of those met too.
|
||||
# (do_not_disable will default to true otherwise)
|
||||
|
@ -172,7 +265,71 @@ class ConditionController(object):
|
|||
|
||||
return error_fields
|
||||
|
||||
def user_quantity_remaining(self, user):
|
||||
@classmethod
|
||||
def _filtered_flags(cls, user, products):
|
||||
'''
|
||||
|
||||
Returns:
|
||||
Sequence[flagbase]: All flags that passed the filter function.
|
||||
|
||||
'''
|
||||
|
||||
types = list(ConditionController._controllers())
|
||||
flagtypes = [i for i in types if issubclass(i, conditions.FlagBase)]
|
||||
|
||||
# Get all flags for the products and categories.
|
||||
prods = (
|
||||
product.flagbase_set.all()
|
||||
for product in products
|
||||
)
|
||||
cats = (
|
||||
category.flagbase_set.all()
|
||||
for category in set(product.category for product in products)
|
||||
)
|
||||
all_flags = reduce(operator.or_, itertools.chain(prods, cats))
|
||||
|
||||
all_subsets = []
|
||||
|
||||
for flagtype in flagtypes:
|
||||
flags = flagtype.objects.filter(id__in=all_flags)
|
||||
ctrl = ConditionController.for_type(flagtype)
|
||||
flags = ctrl.pre_filter(flags, user)
|
||||
all_subsets.append(flags)
|
||||
|
||||
return itertools.chain(*all_subsets)
|
||||
|
||||
@classmethod
|
||||
def pre_filter(cls, queryset, user):
|
||||
''' Returns only the flag conditions that might be available for this
|
||||
user. It should hopefully reduce the number of queries that need to be
|
||||
executed to determine if a flag is met.
|
||||
|
||||
If this filtration implements the same query as is_met, then you should
|
||||
be able to implement ``is_met()`` in terms of this.
|
||||
|
||||
Arguments:
|
||||
|
||||
queryset (Queryset[c]): The canditate conditions.
|
||||
|
||||
user (User): The user for whom we're testing these conditions.
|
||||
|
||||
Returns:
|
||||
Queryset[c]: A subset of the conditions that pass the pre-filter
|
||||
test for this user.
|
||||
|
||||
'''
|
||||
|
||||
# Default implementation does NOTHING.
|
||||
return queryset
|
||||
|
||||
def passes_filter(self, user):
|
||||
''' Returns true if the condition passes the filter '''
|
||||
|
||||
cls = type(self.condition)
|
||||
qs = cls.objects.filter(pk=self.condition.id)
|
||||
return self.condition in self.pre_filter(qs, user)
|
||||
|
||||
def user_quantity_remaining(self, user, filtered=False):
|
||||
''' Returns the number of items covered by this flag condition the
|
||||
user can add to the current cart. This default implementation returns
|
||||
a big number if is_met() is true, otherwise 0.
|
||||
|
@ -180,26 +337,37 @@ class ConditionController(object):
|
|||
Either this method, or is_met() must be overridden in subclasses.
|
||||
'''
|
||||
|
||||
return 99999999 if self.is_met(user) else 0
|
||||
return _BIG_QUANTITY if self.is_met(user, filtered) else 0
|
||||
|
||||
def is_met(self, user):
|
||||
def is_met(self, user, filtered=False):
|
||||
''' Returns True if this flag condition is met, otherwise returns
|
||||
False.
|
||||
|
||||
Either this method, or user_quantity_remaining() must be overridden
|
||||
in subclasses.
|
||||
|
||||
Arguments:
|
||||
|
||||
user (User): The user for whom this test must be met.
|
||||
|
||||
filter (bool): If true, this condition was part of a queryset
|
||||
returned by pre_filter() for this user.
|
||||
|
||||
'''
|
||||
return self.user_quantity_remaining(user) > 0
|
||||
return self.user_quantity_remaining(user, filtered) > 0
|
||||
|
||||
|
||||
class CategoryConditionController(ConditionController):
|
||||
class IsMetByFilter(object):
|
||||
|
||||
def __init__(self, condition):
|
||||
self.condition = condition
|
||||
def is_met(self, user, filtered=False):
|
||||
''' Returns True if this flag condition is met, otherwise returns
|
||||
False. It determines if the condition is met by calling pre_filter
|
||||
with a queryset containing only self.condition. '''
|
||||
|
||||
def is_met(self, user):
|
||||
''' returns True if the user has a product from a category that invokes
|
||||
this condition in one of their carts '''
|
||||
if filtered:
|
||||
return True # Why query again?
|
||||
|
||||
return self.passes_filter(user)
|
||||
|
||||
carts = commerce.Cart.objects.filter(user=user)
|
||||
carts = carts.exclude(status=commerce.Cart.STATUS_RELEASED)
|
||||
|
@ -212,112 +380,176 @@ class CategoryConditionController(ConditionController):
|
|||
).count()
|
||||
return products_count > 0
|
||||
|
||||
class RemainderSetByFilter(object):
|
||||
|
||||
class ProductConditionController(ConditionController):
|
||||
def user_quantity_remaining(self, user, filtered=True):
|
||||
''' returns 0 if the date range is violated, otherwise, it will return
|
||||
the quantity remaining under the stock limit.
|
||||
|
||||
The filter for this condition must add an annotation called "remainder"
|
||||
in order for this to work.
|
||||
'''
|
||||
|
||||
if filtered:
|
||||
if hasattr(self.condition, "remainder"):
|
||||
return self.condition.remainder
|
||||
|
||||
|
||||
|
||||
# Mark self.condition with a remainder
|
||||
qs = type(self.condition).objects.filter(pk=self.condition.id)
|
||||
qs = self.pre_filter(qs, user)
|
||||
|
||||
if len(qs) > 0:
|
||||
return qs[0].remainder
|
||||
else:
|
||||
return 0
|
||||
|
||||
|
||||
class CategoryConditionController(IsMetByFilter, ConditionController):
|
||||
|
||||
@classmethod
|
||||
def pre_filter(self, queryset, user):
|
||||
''' Returns all of the items from queryset where the user has a
|
||||
product from a category invoking that item's condition in one of their
|
||||
carts. '''
|
||||
|
||||
items = commerce.ProductItem.objects.filter(cart__user=user)
|
||||
items = items.exclude(cart__status=commerce.Cart.STATUS_RELEASED)
|
||||
items = items.select_related("product", "product__category")
|
||||
categories = [item.product.category for item in items]
|
||||
|
||||
return queryset.filter(enabling_category__in=categories)
|
||||
|
||||
|
||||
class ProductConditionController(IsMetByFilter, ConditionController):
|
||||
''' Condition tests for ProductFlag and
|
||||
IncludedProductDiscount. '''
|
||||
|
||||
def __init__(self, condition):
|
||||
self.condition = condition
|
||||
@classmethod
|
||||
def pre_filter(self, queryset, user):
|
||||
''' Returns all of the items from queryset where the user has a
|
||||
product invoking that item's condition in one of their carts. '''
|
||||
|
||||
def is_met(self, user):
|
||||
''' returns True if the user has a product that invokes this
|
||||
condition in one of their carts '''
|
||||
items = commerce.ProductItem.objects.filter(cart__user=user)
|
||||
items = items.exclude(cart__status=commerce.Cart.STATUS_RELEASED)
|
||||
items = items.select_related("product", "product__category")
|
||||
products = [item.product for item in items]
|
||||
|
||||
carts = commerce.Cart.objects.filter(user=user)
|
||||
carts = carts.exclude(status=commerce.Cart.STATUS_RELEASED)
|
||||
products_count = commerce.ProductItem.objects.filter(
|
||||
cart__in=carts,
|
||||
product__in=self.condition.enabling_products.all(),
|
||||
).count()
|
||||
return products_count > 0
|
||||
return queryset.filter(enabling_products__in=products)
|
||||
|
||||
|
||||
class TimeOrStockLimitConditionController(ConditionController):
|
||||
class TimeOrStockLimitConditionController(
|
||||
RemainderSetByFilter,
|
||||
ConditionController,
|
||||
):
|
||||
''' Common condition tests for TimeOrStockLimit Flag and
|
||||
Discount.'''
|
||||
|
||||
def __init__(self, ceiling):
|
||||
self.ceiling = ceiling
|
||||
@classmethod
|
||||
def pre_filter(self, queryset, user):
|
||||
''' Returns all of the items from queryset where the date falls into
|
||||
any specified range, but not yet where the stock limit is not yet
|
||||
reached.'''
|
||||
|
||||
def user_quantity_remaining(self, user):
|
||||
''' returns 0 if the date range is violated, otherwise, it will return
|
||||
the quantity remaining under the stock limit. '''
|
||||
|
||||
# Test date range
|
||||
if not self._test_date_range():
|
||||
return 0
|
||||
|
||||
return self._get_remaining_stock(user)
|
||||
|
||||
def _test_date_range(self):
|
||||
now = timezone.now()
|
||||
|
||||
if self.ceiling.start_time is not None:
|
||||
if now < self.ceiling.start_time:
|
||||
return False
|
||||
# Keep items with no start time, or start time not yet met.
|
||||
queryset = queryset.filter(Q(start_time=None) | Q(start_time__lte=now))
|
||||
queryset = queryset.filter(Q(end_time=None) | Q(end_time__gte=now))
|
||||
|
||||
if self.ceiling.end_time is not None:
|
||||
if now > self.ceiling.end_time:
|
||||
return False
|
||||
# Filter out items that have been reserved beyond the limits
|
||||
quantity_or_zero = self._calculate_quantities(user)
|
||||
|
||||
return True
|
||||
remainder = Case(
|
||||
When(limit=None, then=Value(_BIG_QUANTITY)),
|
||||
default=F("limit") - Sum(quantity_or_zero),
|
||||
)
|
||||
|
||||
def _get_remaining_stock(self, user):
|
||||
''' Returns the stock that remains under this ceiling, excluding the
|
||||
user's current cart. '''
|
||||
queryset = queryset.annotate(remainder=remainder)
|
||||
queryset = queryset.filter(remainder__gt=0)
|
||||
|
||||
if self.ceiling.limit is None:
|
||||
return 99999999
|
||||
return queryset
|
||||
|
||||
# We care about all reserved carts, but not the user's current cart
|
||||
@classmethod
|
||||
def _relevant_carts(cls, user):
|
||||
reserved_carts = commerce.Cart.reserved_carts()
|
||||
reserved_carts = reserved_carts.exclude(
|
||||
user=user,
|
||||
status=commerce.Cart.STATUS_ACTIVE,
|
||||
)
|
||||
|
||||
items = self._items()
|
||||
items = items.filter(cart__in=reserved_carts)
|
||||
count = items.aggregate(Sum("quantity"))["quantity__sum"] or 0
|
||||
|
||||
return self.ceiling.limit - count
|
||||
return reserved_carts
|
||||
|
||||
|
||||
class TimeOrStockLimitFlagController(
|
||||
TimeOrStockLimitConditionController):
|
||||
|
||||
def _items(self):
|
||||
category_products = inventory.Product.objects.filter(
|
||||
category__in=self.ceiling.categories.all(),
|
||||
)
|
||||
products = self.ceiling.products.all() | category_products
|
||||
@classmethod
|
||||
def _calculate_quantities(cls, user):
|
||||
reserved_carts = cls._relevant_carts(user)
|
||||
|
||||
product_items = commerce.ProductItem.objects.filter(
|
||||
product__in=products.all(),
|
||||
# Calculate category lines
|
||||
cat_items = F('categories__product__productitem__product__category')
|
||||
reserved_category_products = (
|
||||
Q(categories=cat_items) &
|
||||
Q(categories__product__productitem__cart__in=reserved_carts)
|
||||
)
|
||||
return product_items
|
||||
|
||||
# Calculate product lines
|
||||
reserved_products = (
|
||||
Q(products=F('products__productitem__product')) &
|
||||
Q(products__productitem__cart__in=reserved_carts)
|
||||
)
|
||||
|
||||
category_quantity_in_reserved_carts = When(
|
||||
reserved_category_products,
|
||||
then="categories__product__productitem__quantity",
|
||||
)
|
||||
|
||||
product_quantity_in_reserved_carts = When(
|
||||
reserved_products,
|
||||
then="products__productitem__quantity",
|
||||
)
|
||||
|
||||
quantity_or_zero = Case(
|
||||
category_quantity_in_reserved_carts,
|
||||
product_quantity_in_reserved_carts,
|
||||
default=Value(0),
|
||||
)
|
||||
|
||||
return quantity_or_zero
|
||||
|
||||
|
||||
class TimeOrStockLimitDiscountController(TimeOrStockLimitConditionController):
|
||||
|
||||
def _items(self):
|
||||
discount_items = commerce.DiscountItem.objects.filter(
|
||||
discount=self.ceiling,
|
||||
@classmethod
|
||||
def _calculate_quantities(cls, user):
|
||||
reserved_carts = cls._relevant_carts(user)
|
||||
|
||||
quantity_in_reserved_carts = When(
|
||||
discountitem__cart__in=reserved_carts,
|
||||
then="discountitem__quantity"
|
||||
)
|
||||
return discount_items
|
||||
|
||||
quantity_or_zero = Case(
|
||||
quantity_in_reserved_carts,
|
||||
default=Value(0)
|
||||
)
|
||||
|
||||
return quantity_or_zero
|
||||
|
||||
|
||||
class VoucherConditionController(ConditionController):
|
||||
class VoucherConditionController(IsMetByFilter, ConditionController):
|
||||
''' Condition test for VoucherFlag and VoucherDiscount.'''
|
||||
|
||||
def __init__(self, condition):
|
||||
self.condition = condition
|
||||
@classmethod
|
||||
def pre_filter(self, queryset, user):
|
||||
''' Returns all of the items from queryset where the user has entered
|
||||
a voucher that invokes that item's condition in one of their carts. '''
|
||||
|
||||
def is_met(self, user):
|
||||
''' returns True if the user has the given voucher attached. '''
|
||||
carts_count = commerce.Cart.objects.filter(
|
||||
carts = commerce.Cart.objects.filter(
|
||||
user=user,
|
||||
vouchers=self.condition.voucher,
|
||||
).count()
|
||||
return carts_count > 0
|
||||
)
|
||||
vouchers = [cart.vouchers.all() for cart in carts]
|
||||
|
||||
return queryset.filter(voucher__in=itertools.chain(*vouchers))
|
||||
|
|
|
@ -4,7 +4,11 @@ from conditions import ConditionController
|
|||
from registrasion.models import commerce
|
||||
from registrasion.models import conditions
|
||||
|
||||
from django.db.models import Case
|
||||
from django.db.models import Q
|
||||
from django.db.models import Sum
|
||||
from django.db.models import Value
|
||||
from django.db.models import When
|
||||
|
||||
|
||||
class DiscountAndQuantity(object):
|
||||
|
@ -43,6 +47,62 @@ def available_discounts(user, categories, products):
|
|||
and products. The discounts also list the available quantity for this user,
|
||||
not including products that are pending purchase. '''
|
||||
|
||||
|
||||
|
||||
filtered_clauses = _filtered_discounts(user, categories, products)
|
||||
|
||||
discounts = []
|
||||
|
||||
# Markers so that we don't need to evaluate given conditions more than once
|
||||
accepted_discounts = set()
|
||||
failed_discounts = set()
|
||||
|
||||
for clause in filtered_clauses:
|
||||
discount = clause.discount
|
||||
cond = ConditionController.for_condition(discount)
|
||||
|
||||
past_use_count = discount.past_use_count
|
||||
|
||||
# TODO: add test case --
|
||||
# discount covers 2x prod_1 and 1x prod_2
|
||||
# add 1x prod_2
|
||||
# add 1x prod_1
|
||||
# checkout
|
||||
# discount should be available for prod_1
|
||||
|
||||
if past_use_count >= clause.quantity:
|
||||
# This clause has exceeded its use count
|
||||
pass
|
||||
elif discount not in failed_discounts:
|
||||
# This clause is still available
|
||||
is_accepted = discount in accepted_discounts
|
||||
if is_accepted or cond.is_met(user, filtered=True):
|
||||
# This clause is valid for this user
|
||||
discounts.append(DiscountAndQuantity(
|
||||
discount=discount,
|
||||
clause=clause,
|
||||
quantity=clause.quantity - past_use_count,
|
||||
))
|
||||
accepted_discounts.add(discount)
|
||||
else:
|
||||
# This clause is not valid for this user
|
||||
failed_discounts.add(discount)
|
||||
return discounts
|
||||
|
||||
|
||||
def _filtered_discounts(user, categories, products):
|
||||
'''
|
||||
|
||||
Returns:
|
||||
Sequence[discountbase]: All discounts that passed the filter function.
|
||||
|
||||
'''
|
||||
|
||||
types = list(ConditionController._controllers())
|
||||
discounttypes = [
|
||||
i for i in types if issubclass(i, conditions.DiscountBase)
|
||||
]
|
||||
|
||||
# discounts that match provided categories
|
||||
category_discounts = conditions.DiscountForCategory.objects.filter(
|
||||
category__in=categories
|
||||
|
@ -67,51 +127,56 @@ def available_discounts(user, categories, products):
|
|||
"category",
|
||||
)
|
||||
|
||||
valid_discounts = conditions.DiscountBase.objects.filter(
|
||||
Q(discountforproduct__in=product_discounts) |
|
||||
Q(discountforcategory__in=all_category_discounts)
|
||||
)
|
||||
|
||||
all_subsets = []
|
||||
|
||||
for discounttype in discounttypes:
|
||||
discounts = discounttype.objects.filter(id__in=valid_discounts)
|
||||
ctrl = ConditionController.for_type(discounttype)
|
||||
discounts = ctrl.pre_filter(discounts, user)
|
||||
discounts = _annotate_with_past_uses(discounts, user)
|
||||
all_subsets.append(discounts)
|
||||
|
||||
filtered_discounts = list(itertools.chain(*all_subsets))
|
||||
|
||||
# Map from discount key to itself (contains annotations added by filter)
|
||||
from_filter = dict((i.id, i) for i in filtered_discounts)
|
||||
|
||||
# The set of all potential discounts
|
||||
potential_discounts = set(itertools.chain(
|
||||
product_discounts,
|
||||
all_category_discounts,
|
||||
discount_clauses = set(itertools.chain(
|
||||
product_discounts.filter(discount__in=filtered_discounts),
|
||||
all_category_discounts.filter(discount__in=filtered_discounts),
|
||||
))
|
||||
|
||||
discounts = []
|
||||
# Replace discounts with the filtered ones
|
||||
# These are the correct subclasses (saves query later on), and have
|
||||
# correct annotations from filters if necessary.
|
||||
for clause in discount_clauses:
|
||||
clause.discount = from_filter[clause.discount.id]
|
||||
|
||||
# Markers so that we don't need to evaluate given conditions more than once
|
||||
accepted_discounts = set()
|
||||
failed_discounts = set()
|
||||
return discount_clauses
|
||||
|
||||
for discount in potential_discounts:
|
||||
real_discount = conditions.DiscountBase.objects.get_subclass(
|
||||
pk=discount.discount.pk,
|
||||
)
|
||||
cond = ConditionController.for_condition(real_discount)
|
||||
|
||||
# Count the past uses of the given discount item.
|
||||
# If this user has exceeded the limit for the clause, this clause
|
||||
# is not available any more.
|
||||
past_uses = commerce.DiscountItem.objects.filter(
|
||||
cart__user=user,
|
||||
cart__status=commerce.Cart.STATUS_PAID, # Only past carts count
|
||||
discount=real_discount,
|
||||
)
|
||||
agg = past_uses.aggregate(Sum("quantity"))
|
||||
past_use_count = agg["quantity__sum"]
|
||||
if past_use_count is None:
|
||||
past_use_count = 0
|
||||
def _annotate_with_past_uses(queryset, user):
|
||||
''' Annotates the queryset with a usage count for that discount by the
|
||||
given user. '''
|
||||
|
||||
if past_use_count >= discount.quantity:
|
||||
# This clause has exceeded its use count
|
||||
pass
|
||||
elif real_discount not in failed_discounts:
|
||||
# This clause is still available
|
||||
if real_discount in accepted_discounts or cond.is_met(user):
|
||||
# This clause is valid for this user
|
||||
discounts.append(DiscountAndQuantity(
|
||||
discount=real_discount,
|
||||
clause=discount,
|
||||
quantity=discount.quantity - past_use_count,
|
||||
))
|
||||
accepted_discounts.add(real_discount)
|
||||
else:
|
||||
# This clause is not valid for this user
|
||||
failed_discounts.add(real_discount)
|
||||
return discounts
|
||||
past_use_quantity = When(
|
||||
(
|
||||
Q(discountitem__cart__user=user) &
|
||||
Q(discountitem__cart__status=commerce.Cart.STATUS_PAID)
|
||||
),
|
||||
then="discountitem__quantity",
|
||||
)
|
||||
|
||||
past_use_quantity_or_zero = Case(
|
||||
past_use_quantity,
|
||||
default=Value(0),
|
||||
)
|
||||
|
||||
queryset = queryset.annotate(past_use_count=Sum(past_use_quantity_or_zero))
|
||||
return queryset
|
||||
|
|
|
@ -6,6 +6,8 @@ from django.core.exceptions import ValidationError
|
|||
from controller_helpers import TestingCartController
|
||||
from test_cart import RegistrationCartTestCase
|
||||
|
||||
from registrasion.controllers.discount import available_discounts
|
||||
from registrasion.controllers.product import ProductController
|
||||
from registrasion.models import commerce
|
||||
from registrasion.models import conditions
|
||||
|
||||
|
@ -135,6 +137,39 @@ class CeilingsTestCases(RegistrationCartTestCase):
|
|||
with self.assertRaises(ValidationError):
|
||||
first_cart.validate_cart()
|
||||
|
||||
def test_discount_ceiling_aggregates_products(self):
|
||||
# Create two carts, add 1xprod_1 to each. Ceiling should disappear
|
||||
# after second.
|
||||
self.make_discount_ceiling(
|
||||
"Multi-product limit discount ceiling",
|
||||
limit=2,
|
||||
)
|
||||
for i in xrange(2):
|
||||
cart = TestingCartController.for_user(self.USER_1)
|
||||
cart.add_to_cart(self.PROD_1, 1)
|
||||
cart.next_cart()
|
||||
|
||||
discounts = available_discounts(self.USER_1, [], [self.PROD_1])
|
||||
|
||||
self.assertEqual(0, len(discounts))
|
||||
|
||||
def test_flag_ceiling_aggregates_products(self):
|
||||
# Create two carts, add 1xprod_1 to each. Ceiling should disappear
|
||||
# after second.
|
||||
self.make_ceiling("Multi-product limit ceiling", limit=2)
|
||||
|
||||
for i in xrange(2):
|
||||
cart = TestingCartController.for_user(self.USER_1)
|
||||
cart.add_to_cart(self.PROD_1, 1)
|
||||
cart.next_cart()
|
||||
|
||||
products = ProductController.available_products(
|
||||
self.USER_1,
|
||||
products=[self.PROD_1],
|
||||
)
|
||||
|
||||
self.assertEqual(0, len(products))
|
||||
|
||||
def test_items_released_from_ceiling_by_refund(self):
|
||||
self.make_ceiling("Limit ceiling", limit=1)
|
||||
|
||||
|
|
|
@ -1,2 +1,2 @@
|
|||
[flake8]
|
||||
exclude = registrasion/migrations/*, build/*, docs/*
|
||||
exclude = registrasion/migrations/*, build/*, docs/*, dist/*
|
||||
|
|
Loading…
Reference in a new issue