Merge branch 'batch_cache'
This commit is contained in:
commit
ded5114073
9 changed files with 430 additions and 179 deletions
119
registrasion/controllers/batch.py
Normal file
119
registrasion/controllers/batch.py
Normal file
|
@ -0,0 +1,119 @@
|
||||||
|
import contextlib
|
||||||
|
import functools
|
||||||
|
|
||||||
|
from django.contrib.auth.models import User
|
||||||
|
|
||||||
|
|
||||||
|
class BatchController(object):
|
||||||
|
''' Batches are sets of operations where certain queries for users may be
|
||||||
|
repeated, but are also unlikely change within the boundaries of the batch.
|
||||||
|
|
||||||
|
Batches are keyed per-user. You can mark the edge of the batch with the
|
||||||
|
``batch`` context manager. If you nest calls to ``batch``, only the
|
||||||
|
outermost call will have the effect of ending the batch.
|
||||||
|
|
||||||
|
Batches store results for functions wrapped with ``memoise``. These results
|
||||||
|
for the user are flushed at the end of the batch.
|
||||||
|
|
||||||
|
If a return for a memoised function has a callable attribute called
|
||||||
|
``end_batch``, that attribute will be called at the end of the batch.
|
||||||
|
|
||||||
|
'''
|
||||||
|
|
||||||
|
_user_caches = {}
|
||||||
|
_NESTING_KEY = "nesting_count"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def batch(cls, user):
|
||||||
|
''' Marks the entry point for a batch for the given user. '''
|
||||||
|
|
||||||
|
cls._enter_batch_context(user)
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
# Make sure we clean up in case of errors.
|
||||||
|
cls._exit_batch_context(user)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _enter_batch_context(cls, user):
|
||||||
|
if user not in cls._user_caches:
|
||||||
|
cls._user_caches[user] = cls._new_cache()
|
||||||
|
|
||||||
|
cache = cls._user_caches[user]
|
||||||
|
cache[cls._NESTING_KEY] += 1
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _exit_batch_context(cls, user):
|
||||||
|
cache = cls._user_caches[user]
|
||||||
|
cache[cls._NESTING_KEY] -= 1
|
||||||
|
|
||||||
|
if cache[cls._NESTING_KEY] == 0:
|
||||||
|
cls._call_end_batch_methods(user)
|
||||||
|
del cls._user_caches[user]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _call_end_batch_methods(cls, user):
|
||||||
|
cache = cls._user_caches[user]
|
||||||
|
ended = set()
|
||||||
|
while True:
|
||||||
|
keys = set(cache.keys())
|
||||||
|
if ended == keys:
|
||||||
|
break
|
||||||
|
keys_to_end = keys - ended
|
||||||
|
for key in keys_to_end:
|
||||||
|
item = cache[key]
|
||||||
|
if hasattr(item, 'end_batch') and callable(item.end_batch):
|
||||||
|
item.end_batch()
|
||||||
|
ended = ended | keys_to_end
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def memoise(cls, func):
|
||||||
|
''' Decorator that stores the result of the stored function in the
|
||||||
|
user's results cache until the batch completes. Keyword arguments are
|
||||||
|
not yet supported.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
func (callable(*a)): The function whose results we want
|
||||||
|
to store. The positional arguments, ``a``, are used as cache
|
||||||
|
keys.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
callable(*a): The memosing version of ``func``.
|
||||||
|
|
||||||
|
'''
|
||||||
|
|
||||||
|
@functools.wraps(func)
|
||||||
|
def f(*a):
|
||||||
|
|
||||||
|
for arg in a:
|
||||||
|
if isinstance(arg, User):
|
||||||
|
user = arg
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
raise ValueError("One position argument must be a User")
|
||||||
|
|
||||||
|
func_key = (func, tuple(a))
|
||||||
|
cache = cls.get_cache(user)
|
||||||
|
|
||||||
|
if func_key not in cache:
|
||||||
|
cache[func_key] = func(*a)
|
||||||
|
|
||||||
|
return cache[func_key]
|
||||||
|
|
||||||
|
return f
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_cache(cls, user):
|
||||||
|
if user not in cls._user_caches:
|
||||||
|
# Return blank cache here, we'll just discard :)
|
||||||
|
return cls._new_cache()
|
||||||
|
|
||||||
|
return cls._user_caches[user]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _new_cache(cls):
|
||||||
|
''' Returns a new cache dictionary. '''
|
||||||
|
cache = {}
|
||||||
|
cache[cls._NESTING_KEY] = 0
|
||||||
|
return cache
|
|
@ -16,6 +16,7 @@ from registrasion.models import commerce
|
||||||
from registrasion.models import conditions
|
from registrasion.models import conditions
|
||||||
from registrasion.models import inventory
|
from registrasion.models import inventory
|
||||||
|
|
||||||
|
from.batch import BatchController
|
||||||
from .category import CategoryController
|
from .category import CategoryController
|
||||||
from .discount import DiscountController
|
from .discount import DiscountController
|
||||||
from .flag import FlagController
|
from .flag import FlagController
|
||||||
|
@ -34,10 +35,11 @@ def _modifies_cart(func):
|
||||||
def inner(self, *a, **k):
|
def inner(self, *a, **k):
|
||||||
self._fail_if_cart_is_not_active()
|
self._fail_if_cart_is_not_active()
|
||||||
with transaction.atomic():
|
with transaction.atomic():
|
||||||
with CartController.operations_batch(self.cart.user) as mark:
|
with BatchController.batch(self.cart.user):
|
||||||
mark.mark = True # Marker that we've modified the cart
|
# Mark the version of self in the batch cache as modified
|
||||||
|
memoised = self.for_user(self.cart.user)
|
||||||
|
memoised._modified_by_batch = True
|
||||||
return func(self, *a, **k)
|
return func(self, *a, **k)
|
||||||
|
|
||||||
return inner
|
return inner
|
||||||
|
|
||||||
|
|
||||||
|
@ -47,6 +49,7 @@ class CartController(object):
|
||||||
self.cart = cart
|
self.cart = cart
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@BatchController.memoise
|
||||||
def for_user(cls, user):
|
def for_user(cls, user):
|
||||||
''' Returns the user's current cart, or creates a new cart
|
''' Returns the user's current cart, or creates a new cart
|
||||||
if there isn't one ready yet. '''
|
if there isn't one ready yet. '''
|
||||||
|
@ -64,59 +67,6 @@ class CartController(object):
|
||||||
)
|
)
|
||||||
return cls(existing)
|
return cls(existing)
|
||||||
|
|
||||||
# Marks the carts that are currently in batches
|
|
||||||
_FOR_USER = {}
|
|
||||||
_BATCH_COUNT = collections.defaultdict(int)
|
|
||||||
_MODIFIED_CARTS = set()
|
|
||||||
|
|
||||||
class _ModificationMarker(object):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@contextlib.contextmanager
|
|
||||||
def operations_batch(cls, user):
|
|
||||||
''' Marks the boundary for a batch of operations on a user's cart.
|
|
||||||
|
|
||||||
These markers can be nested. Only on exiting the outermost marker will
|
|
||||||
a batch be ended.
|
|
||||||
|
|
||||||
When a batch is ended, discounts are recalculated, and the cart's
|
|
||||||
revision is increased.
|
|
||||||
'''
|
|
||||||
|
|
||||||
if user not in cls._FOR_USER:
|
|
||||||
_ctrl = cls.for_user(user)
|
|
||||||
cls._FOR_USER[user] = (_ctrl, _ctrl.cart.id)
|
|
||||||
|
|
||||||
ctrl, _id = cls._FOR_USER[user]
|
|
||||||
|
|
||||||
cls._BATCH_COUNT[_id] += 1
|
|
||||||
try:
|
|
||||||
success = False
|
|
||||||
|
|
||||||
marker = cls._ModificationMarker()
|
|
||||||
yield marker
|
|
||||||
|
|
||||||
if hasattr(marker, "mark"):
|
|
||||||
cls._MODIFIED_CARTS.add(_id)
|
|
||||||
|
|
||||||
success = True
|
|
||||||
finally:
|
|
||||||
|
|
||||||
cls._BATCH_COUNT[_id] -= 1
|
|
||||||
|
|
||||||
# Only end on the outermost batch marker, and only if
|
|
||||||
# it excited cleanly, and a modification occurred
|
|
||||||
modified = _id in cls._MODIFIED_CARTS
|
|
||||||
outermost = cls._BATCH_COUNT[_id] == 0
|
|
||||||
if modified and outermost and success:
|
|
||||||
ctrl._end_batch()
|
|
||||||
cls._MODIFIED_CARTS.remove(_id)
|
|
||||||
|
|
||||||
# Clear out the cache on the outermost operation
|
|
||||||
if outermost:
|
|
||||||
del cls._FOR_USER[user]
|
|
||||||
|
|
||||||
def _fail_if_cart_is_not_active(self):
|
def _fail_if_cart_is_not_active(self):
|
||||||
self.cart.refresh_from_db()
|
self.cart.refresh_from_db()
|
||||||
if self.cart.status != commerce.Cart.STATUS_ACTIVE:
|
if self.cart.status != commerce.Cart.STATUS_ACTIVE:
|
||||||
|
@ -144,6 +94,13 @@ class CartController(object):
|
||||||
self.cart.time_last_updated = timezone.now()
|
self.cart.time_last_updated = timezone.now()
|
||||||
self.cart.reservation_duration = max(reservations)
|
self.cart.reservation_duration = max(reservations)
|
||||||
|
|
||||||
|
|
||||||
|
def end_batch(self):
|
||||||
|
''' Calls ``_end_batch`` if a modification has been performed in the
|
||||||
|
previous batch. '''
|
||||||
|
if hasattr(self,'_modified_by_batch'):
|
||||||
|
self._end_batch()
|
||||||
|
|
||||||
def _end_batch(self):
|
def _end_batch(self):
|
||||||
''' Performs operations that occur occur at the end of a batch of
|
''' Performs operations that occur occur at the end of a batch of
|
||||||
product changes/voucher applications etc.
|
product changes/voucher applications etc.
|
||||||
|
@ -217,16 +174,14 @@ class CartController(object):
|
||||||
errors = []
|
errors = []
|
||||||
|
|
||||||
# Pre-annotate products
|
# Pre-annotate products
|
||||||
products = [p for (p, q) in product_quantities]
|
remainders = ProductController.user_remainders(self.cart.user)
|
||||||
r = ProductController.attach_user_remainders(self.cart.user, products)
|
|
||||||
with_remainders = dict((p, p) for p in r)
|
|
||||||
|
|
||||||
# Test each product limit here
|
# Test each product limit here
|
||||||
for product, quantity in product_quantities:
|
for product, quantity in product_quantities:
|
||||||
if quantity < 0:
|
if quantity < 0:
|
||||||
errors.append((product, "Value must be zero or greater."))
|
errors.append((product, "Value must be zero or greater."))
|
||||||
|
|
||||||
limit = with_remainders[product].remainder
|
limit = remainders[product.id]
|
||||||
|
|
||||||
if quantity > limit:
|
if quantity > limit:
|
||||||
errors.append((
|
errors.append((
|
||||||
|
@ -242,12 +197,11 @@ class CartController(object):
|
||||||
by_cat[product.category].append((product, quantity))
|
by_cat[product.category].append((product, quantity))
|
||||||
|
|
||||||
# Pre-annotate categories
|
# Pre-annotate categories
|
||||||
r = CategoryController.attach_user_remainders(self.cart.user, by_cat)
|
remainders = CategoryController.user_remainders(self.cart.user)
|
||||||
with_remainders = dict((cat, cat) for cat in r)
|
|
||||||
|
|
||||||
# Test each category limit here
|
# Test each category limit here
|
||||||
for category in by_cat:
|
for category in by_cat:
|
||||||
limit = with_remainders[category].remainder
|
limit = remainders[category.id]
|
||||||
|
|
||||||
# Get the amount so far in the cart
|
# Get the amount so far in the cart
|
||||||
to_add = sum(i[1] for i in by_cat[category])
|
to_add = sum(i[1] for i in by_cat[category])
|
||||||
|
|
|
@ -7,6 +7,7 @@ from django.db.models import Sum
|
||||||
from django.db.models import When
|
from django.db.models import When
|
||||||
from django.db.models import Value
|
from django.db.models import Value
|
||||||
|
|
||||||
|
from .batch import BatchController
|
||||||
|
|
||||||
class AllProducts(object):
|
class AllProducts(object):
|
||||||
pass
|
pass
|
||||||
|
@ -39,17 +40,17 @@ class CategoryController(object):
|
||||||
return set(i.category for i in available)
|
return set(i.category for i in available)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def attach_user_remainders(cls, user, categories):
|
@BatchController.memoise
|
||||||
|
def user_remainders(cls, user):
|
||||||
'''
|
'''
|
||||||
|
|
||||||
Return:
|
Return:
|
||||||
queryset(inventory.Product): A queryset containing items from
|
Mapping[int->int]: A dictionary that maps the category ID to the
|
||||||
``categories``, with an extra attribute -- remainder = the amount
|
user's remainder for that category.
|
||||||
of items from this category that is remaining.
|
|
||||||
'''
|
'''
|
||||||
|
|
||||||
ids = [category.id for category in categories]
|
categories = inventory.Category.objects.all()
|
||||||
categories = inventory.Category.objects.filter(id__in=ids)
|
|
||||||
|
|
||||||
cart_filter = (
|
cart_filter = (
|
||||||
Q(product__productitem__cart__user=user) &
|
Q(product__productitem__cart__user=user) &
|
||||||
|
@ -73,12 +74,4 @@ class CategoryController(object):
|
||||||
|
|
||||||
categories = categories.annotate(remainder=remainder)
|
categories = categories.annotate(remainder=remainder)
|
||||||
|
|
||||||
return categories
|
return dict((cat.id, cat.remainder) for cat in categories)
|
||||||
|
|
||||||
def user_quantity_remaining(self, user):
|
|
||||||
''' Returns the quantity of this product that the user add in the
|
|
||||||
current cart. '''
|
|
||||||
|
|
||||||
with_remainders = self.attach_user_remainders(user, [self.category])
|
|
||||||
|
|
||||||
return with_remainders[0].remainder
|
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
import itertools
|
import itertools
|
||||||
|
|
||||||
from conditions import ConditionController
|
from .batch import BatchController
|
||||||
|
from .conditions import ConditionController
|
||||||
|
|
||||||
from registrasion.models import commerce
|
from registrasion.models import commerce
|
||||||
from registrasion.models import conditions
|
from registrasion.models import conditions
|
||||||
|
|
||||||
|
@ -10,7 +12,6 @@ from django.db.models import Sum
|
||||||
from django.db.models import Value
|
from django.db.models import Value
|
||||||
from django.db.models import When
|
from django.db.models import When
|
||||||
|
|
||||||
|
|
||||||
class DiscountAndQuantity(object):
|
class DiscountAndQuantity(object):
|
||||||
''' Represents a discount that can be applied to a product or category
|
''' Represents a discount that can be applied to a product or category
|
||||||
for a given user.
|
for a given user.
|
||||||
|
@ -50,7 +51,22 @@ class DiscountController(object):
|
||||||
categories and products. The discounts also list the available quantity
|
categories and products. The discounts also list the available quantity
|
||||||
for this user, not including products that are pending purchase. '''
|
for this user, not including products that are pending purchase. '''
|
||||||
|
|
||||||
filtered_clauses = cls._filtered_discounts(user, categories, products)
|
filtered_clauses = cls._filtered_clauses(user)
|
||||||
|
|
||||||
|
# clauses that match provided categories
|
||||||
|
categories = set(categories)
|
||||||
|
# clauses that match provided products
|
||||||
|
products = set(products)
|
||||||
|
# clauses that match categories for provided products
|
||||||
|
product_categories = set(product.category for product in products)
|
||||||
|
# (Not relevant: clauses that match products in provided categories)
|
||||||
|
all_categories = categories | product_categories
|
||||||
|
|
||||||
|
filtered_clauses = (
|
||||||
|
clause for clause in filtered_clauses
|
||||||
|
if hasattr(clause, 'product') and clause.product in products or
|
||||||
|
hasattr(clause, 'category') and clause.category in all_categories
|
||||||
|
)
|
||||||
|
|
||||||
discounts = []
|
discounts = []
|
||||||
|
|
||||||
|
@ -84,12 +100,13 @@ class DiscountController(object):
|
||||||
return discounts
|
return discounts
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _filtered_discounts(cls, user, categories, products):
|
@BatchController.memoise
|
||||||
|
def _filtered_clauses(cls, user):
|
||||||
'''
|
'''
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Sequence[discountbase]: All discounts that passed the filter
|
Sequence[DiscountForProduct | DiscountForCategory]: All clauses
|
||||||
function.
|
that passed the filter function.
|
||||||
|
|
||||||
'''
|
'''
|
||||||
|
|
||||||
|
@ -98,42 +115,22 @@ class DiscountController(object):
|
||||||
i for i in types if issubclass(i, conditions.DiscountBase)
|
i for i in types if issubclass(i, conditions.DiscountBase)
|
||||||
]
|
]
|
||||||
|
|
||||||
# discounts that match provided categories
|
product_clauses = conditions.DiscountForProduct.objects.all()
|
||||||
category_discounts = conditions.DiscountForCategory.objects.filter(
|
product_clauses = product_clauses.select_related(
|
||||||
category__in=categories
|
"discount",
|
||||||
)
|
|
||||||
# discounts that match provided products
|
|
||||||
product_discounts = conditions.DiscountForProduct.objects.filter(
|
|
||||||
product__in=products
|
|
||||||
)
|
|
||||||
# discounts that match categories for provided products
|
|
||||||
product_category_discounts = conditions.DiscountForCategory.objects
|
|
||||||
product_category_discounts = product_category_discounts.filter(
|
|
||||||
category__in=(product.category for product in products)
|
|
||||||
)
|
|
||||||
# (Not relevant: discounts that match products in provided categories)
|
|
||||||
|
|
||||||
product_discounts = product_discounts.select_related(
|
|
||||||
"product",
|
"product",
|
||||||
"product__category",
|
"product__category",
|
||||||
)
|
)
|
||||||
|
category_clauses = conditions.DiscountForCategory.objects.all()
|
||||||
all_category_discounts = (
|
category_clauses = category_clauses.select_related(
|
||||||
category_discounts | product_category_discounts
|
|
||||||
)
|
|
||||||
all_category_discounts = all_category_discounts.select_related(
|
|
||||||
"category",
|
"category",
|
||||||
)
|
"discount",
|
||||||
|
|
||||||
valid_discounts = conditions.DiscountBase.objects.filter(
|
|
||||||
Q(discountforproduct__in=product_discounts) |
|
|
||||||
Q(discountforcategory__in=all_category_discounts)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
all_subsets = []
|
all_subsets = []
|
||||||
|
|
||||||
for discounttype in discounttypes:
|
for discounttype in discounttypes:
|
||||||
discounts = discounttype.objects.filter(id__in=valid_discounts)
|
discounts = discounttype.objects.all()
|
||||||
ctrl = ConditionController.for_type(discounttype)
|
ctrl = ConditionController.for_type(discounttype)
|
||||||
discounts = ctrl.pre_filter(discounts, user)
|
discounts = ctrl.pre_filter(discounts, user)
|
||||||
all_subsets.append(discounts)
|
all_subsets.append(discounts)
|
||||||
|
@ -145,8 +142,8 @@ class DiscountController(object):
|
||||||
from_filter = dict((i.id, i) for i in filtered_discounts)
|
from_filter = dict((i.id, i) for i in filtered_discounts)
|
||||||
|
|
||||||
clause_sets = (
|
clause_sets = (
|
||||||
product_discounts.filter(discount__in=filtered_discounts),
|
product_clauses.filter(discount__in=filtered_discounts),
|
||||||
all_category_discounts.filter(discount__in=filtered_discounts),
|
category_clauses.filter(discount__in=filtered_discounts),
|
||||||
)
|
)
|
||||||
|
|
||||||
clause_sets = (
|
clause_sets = (
|
||||||
|
|
|
@ -6,6 +6,7 @@ from collections import namedtuple
|
||||||
from django.db.models import Count
|
from django.db.models import Count
|
||||||
from django.db.models import Q
|
from django.db.models import Q
|
||||||
|
|
||||||
|
from .batch import BatchController
|
||||||
from .conditions import ConditionController
|
from .conditions import ConditionController
|
||||||
|
|
||||||
from registrasion.models import conditions
|
from registrasion.models import conditions
|
||||||
|
@ -47,8 +48,6 @@ class FlagController(object):
|
||||||
a list is returned containing all of the products that are *not
|
a list is returned containing all of the products that are *not
|
||||||
enabled*. '''
|
enabled*. '''
|
||||||
|
|
||||||
print "GREPME: test_flags()"
|
|
||||||
|
|
||||||
if products is not None and product_quantities is not None:
|
if products is not None and product_quantities is not None:
|
||||||
raise ValueError("Please specify only products or "
|
raise ValueError("Please specify only products or "
|
||||||
"product_quantities")
|
"product_quantities")
|
||||||
|
@ -62,7 +61,7 @@ class FlagController(object):
|
||||||
|
|
||||||
if products:
|
if products:
|
||||||
# Simplify the query.
|
# Simplify the query.
|
||||||
all_conditions = cls._filtered_flags(user, products)
|
all_conditions = cls._filtered_flags(user)
|
||||||
else:
|
else:
|
||||||
all_conditions = []
|
all_conditions = []
|
||||||
|
|
||||||
|
@ -86,6 +85,8 @@ class FlagController(object):
|
||||||
# from the categories covered by this condition
|
# from the categories covered by this condition
|
||||||
|
|
||||||
ids = [product.id for product in products]
|
ids = [product.id for product in products]
|
||||||
|
|
||||||
|
# TODO: This is re-evaluated a lot.
|
||||||
all_products = inventory.Product.objects.filter(id__in=ids)
|
all_products = inventory.Product.objects.filter(id__in=ids)
|
||||||
cond = (
|
cond = (
|
||||||
Q(flagbase_set=condition) |
|
Q(flagbase_set=condition) |
|
||||||
|
@ -117,7 +118,7 @@ class FlagController(object):
|
||||||
if not met and product not in messages:
|
if not met and product not in messages:
|
||||||
messages[product] = message
|
messages[product] = message
|
||||||
|
|
||||||
total_flags = FlagCounter.count()
|
total_flags = FlagCounter.count(user)
|
||||||
|
|
||||||
valid = {}
|
valid = {}
|
||||||
|
|
||||||
|
@ -160,7 +161,8 @@ class FlagController(object):
|
||||||
return error_fields
|
return error_fields
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _filtered_flags(cls, user, products):
|
@BatchController.memoise
|
||||||
|
def _filtered_flags(cls, user):
|
||||||
'''
|
'''
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
@ -171,26 +173,15 @@ class FlagController(object):
|
||||||
types = list(ConditionController._controllers())
|
types = list(ConditionController._controllers())
|
||||||
flagtypes = [i for i in types if issubclass(i, conditions.FlagBase)]
|
flagtypes = [i for i in types if issubclass(i, conditions.FlagBase)]
|
||||||
|
|
||||||
# Get all flags for the products and categories.
|
|
||||||
prods = (
|
|
||||||
product.flagbase_set.all()
|
|
||||||
for product in products
|
|
||||||
)
|
|
||||||
cats = (
|
|
||||||
category.flagbase_set.all()
|
|
||||||
for category in set(product.category for product in products)
|
|
||||||
)
|
|
||||||
all_flags = reduce(operator.or_, itertools.chain(prods, cats))
|
|
||||||
|
|
||||||
all_subsets = []
|
all_subsets = []
|
||||||
|
|
||||||
for flagtype in flagtypes:
|
for flagtype in flagtypes:
|
||||||
flags = flagtype.objects.filter(id__in=all_flags)
|
flags = flagtype.objects.all()
|
||||||
ctrl = ConditionController.for_type(flagtype)
|
ctrl = ConditionController.for_type(flagtype)
|
||||||
flags = ctrl.pre_filter(flags, user)
|
flags = ctrl.pre_filter(flags, user)
|
||||||
all_subsets.append(flags)
|
all_subsets.append(flags)
|
||||||
|
|
||||||
return itertools.chain(*all_subsets)
|
return list(itertools.chain(*all_subsets))
|
||||||
|
|
||||||
|
|
||||||
ConditionAndRemainder = namedtuple(
|
ConditionAndRemainder = namedtuple(
|
||||||
|
@ -220,11 +211,11 @@ _ConditionsCount = namedtuple(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# TODO: this should be cacheable.
|
|
||||||
class FlagCounter(_FlagCounter):
|
class FlagCounter(_FlagCounter):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def count(cls):
|
@BatchController.memoise
|
||||||
|
def count(cls, user):
|
||||||
# Get the count of how many conditions should exist per product
|
# Get the count of how many conditions should exist per product
|
||||||
flagbases = conditions.FlagBase.objects
|
flagbases = conditions.FlagBase.objects
|
||||||
|
|
||||||
|
|
|
@ -9,6 +9,7 @@ from django.db.models import Value
|
||||||
from registrasion.models import commerce
|
from registrasion.models import commerce
|
||||||
from registrasion.models import inventory
|
from registrasion.models import inventory
|
||||||
|
|
||||||
|
from .batch import BatchController
|
||||||
from .category import CategoryController
|
from .category import CategoryController
|
||||||
from .flag import FlagController
|
from .flag import FlagController
|
||||||
|
|
||||||
|
@ -34,18 +35,14 @@ class ProductController(object):
|
||||||
if products is not None:
|
if products is not None:
|
||||||
all_products = set(itertools.chain(all_products, products))
|
all_products = set(itertools.chain(all_products, products))
|
||||||
|
|
||||||
categories = set(product.category for product in all_products)
|
category_remainders = CategoryController.user_remainders(user)
|
||||||
r = CategoryController.attach_user_remainders(user, categories)
|
product_remainders = ProductController.user_remainders(user)
|
||||||
cat_quants = dict((c, c) for c in r)
|
|
||||||
|
|
||||||
r = ProductController.attach_user_remainders(user, all_products)
|
|
||||||
prod_quants = dict((p, p) for p in r)
|
|
||||||
|
|
||||||
passed_limits = set(
|
passed_limits = set(
|
||||||
product
|
product
|
||||||
for product in all_products
|
for product in all_products
|
||||||
if cat_quants[product.category].remainder > 0
|
if category_remainders[product.category.id] > 0
|
||||||
if prod_quants[product].remainder > 0
|
if product_remainders[product.id] > 0
|
||||||
)
|
)
|
||||||
|
|
||||||
failed_and_messages = FlagController.test_flags(
|
failed_and_messages = FlagController.test_flags(
|
||||||
|
@ -59,17 +56,16 @@ class ProductController(object):
|
||||||
return out
|
return out
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def attach_user_remainders(cls, user, products):
|
@BatchController.memoise
|
||||||
|
def user_remainders(cls, user):
|
||||||
'''
|
'''
|
||||||
|
|
||||||
Return:
|
Return:
|
||||||
queryset(inventory.Product): A queryset containing items from
|
Mapping[int->int]: A dictionary that maps the product ID to the
|
||||||
``product``, with an extra attribute -- remainder = the amount of
|
user's remainder for that product.
|
||||||
this item that is remaining.
|
|
||||||
'''
|
'''
|
||||||
|
|
||||||
ids = [product.id for product in products]
|
products = inventory.Product.objects.all()
|
||||||
products = inventory.Product.objects.filter(id__in=ids)
|
|
||||||
|
|
||||||
cart_filter = (
|
cart_filter = (
|
||||||
Q(productitem__cart__user=user) &
|
Q(productitem__cart__user=user) &
|
||||||
|
@ -93,12 +89,4 @@ class ProductController(object):
|
||||||
|
|
||||||
products = products.annotate(remainder=remainder)
|
products = products.annotate(remainder=remainder)
|
||||||
|
|
||||||
return products
|
return dict((product.id, product.remainder) for product in products)
|
||||||
|
|
||||||
def user_quantity_remaining(self, user):
|
|
||||||
''' Returns the quantity of this product that the user add in the
|
|
||||||
current cart. '''
|
|
||||||
|
|
||||||
with_remainders = self.attach_user_remainders(user, [self.product])
|
|
||||||
|
|
||||||
return with_remainders[0].remainder
|
|
||||||
|
|
144
registrasion/tests/test_batch.py
Normal file
144
registrasion/tests/test_batch.py
Normal file
|
@ -0,0 +1,144 @@
|
||||||
|
import datetime
|
||||||
|
import pytz
|
||||||
|
|
||||||
|
from django.core.exceptions import ValidationError
|
||||||
|
|
||||||
|
from controller_helpers import TestingCartController
|
||||||
|
from test_cart import RegistrationCartTestCase
|
||||||
|
|
||||||
|
from registrasion.controllers.batch import BatchController
|
||||||
|
from registrasion.controllers.discount import DiscountController
|
||||||
|
from registrasion.controllers.product import ProductController
|
||||||
|
from registrasion.models import commerce
|
||||||
|
from registrasion.models import conditions
|
||||||
|
|
||||||
|
UTC = pytz.timezone('UTC')
|
||||||
|
|
||||||
|
|
||||||
|
class BatchTestCase(RegistrationCartTestCase):
|
||||||
|
|
||||||
|
def test_no_caches_outside_of_batches(self):
|
||||||
|
cache_1 = BatchController.get_cache(self.USER_1)
|
||||||
|
cache_2 = BatchController.get_cache(self.USER_2)
|
||||||
|
|
||||||
|
# Identity testing is important here
|
||||||
|
self.assertIsNot(cache_1, cache_2)
|
||||||
|
|
||||||
|
def test_cache_clears_at_batch_exit(self):
|
||||||
|
with BatchController.batch(self.USER_1):
|
||||||
|
cache_1 = BatchController.get_cache(self.USER_1)
|
||||||
|
|
||||||
|
cache_2 = BatchController.get_cache(self.USER_1)
|
||||||
|
|
||||||
|
self.assertIsNot(cache_1, cache_2)
|
||||||
|
|
||||||
|
def test_caches_identical_within_nestings(self):
|
||||||
|
with BatchController.batch(self.USER_1):
|
||||||
|
cache_1 = BatchController.get_cache(self.USER_1)
|
||||||
|
|
||||||
|
with BatchController.batch(self.USER_2):
|
||||||
|
cache_2 = BatchController.get_cache(self.USER_1)
|
||||||
|
|
||||||
|
cache_3 = BatchController.get_cache(self.USER_1)
|
||||||
|
|
||||||
|
self.assertIs(cache_1, cache_2)
|
||||||
|
self.assertIs(cache_2, cache_3)
|
||||||
|
|
||||||
|
def test_caches_are_independent_for_different_users(self):
|
||||||
|
with BatchController.batch(self.USER_1):
|
||||||
|
cache_1 = BatchController.get_cache(self.USER_1)
|
||||||
|
|
||||||
|
with BatchController.batch(self.USER_2):
|
||||||
|
cache_2 = BatchController.get_cache(self.USER_2)
|
||||||
|
|
||||||
|
self.assertIsNot(cache_1, cache_2)
|
||||||
|
|
||||||
|
def test_cache_clears_are_independent_for_different_users(self):
|
||||||
|
with BatchController.batch(self.USER_1):
|
||||||
|
cache_1 = BatchController.get_cache(self.USER_1)
|
||||||
|
|
||||||
|
with BatchController.batch(self.USER_2):
|
||||||
|
cache_2 = BatchController.get_cache(self.USER_2)
|
||||||
|
|
||||||
|
with BatchController.batch(self.USER_2):
|
||||||
|
cache_3 = BatchController.get_cache(self.USER_2)
|
||||||
|
|
||||||
|
cache_4 = BatchController.get_cache(self.USER_1)
|
||||||
|
|
||||||
|
self.assertIs(cache_1, cache_4)
|
||||||
|
self.assertIsNot(cache_1, cache_2)
|
||||||
|
self.assertIsNot(cache_2, cache_3)
|
||||||
|
|
||||||
|
def test_new_caches_for_new_batches(self):
|
||||||
|
with BatchController.batch(self.USER_1):
|
||||||
|
cache_1 = BatchController.get_cache(self.USER_1)
|
||||||
|
|
||||||
|
with BatchController.batch(self.USER_1):
|
||||||
|
cache_2 = BatchController.get_cache(self.USER_1)
|
||||||
|
|
||||||
|
with BatchController.batch(self.USER_1):
|
||||||
|
cache_3 = BatchController.get_cache(self.USER_1)
|
||||||
|
|
||||||
|
self.assertIs(cache_2, cache_3)
|
||||||
|
self.assertIsNot(cache_1, cache_2)
|
||||||
|
|
||||||
|
def test_memoisation_happens_in_batch_context(self):
|
||||||
|
with BatchController.batch(self.USER_1):
|
||||||
|
output_1 = self._memoiseme(self.USER_1)
|
||||||
|
|
||||||
|
with BatchController.batch(self.USER_1):
|
||||||
|
output_2 = self._memoiseme(self.USER_1)
|
||||||
|
|
||||||
|
self.assertIs(output_1, output_2)
|
||||||
|
|
||||||
|
def test_memoisaion_does_not_happen_outside_batch_context(self):
|
||||||
|
output_1 = self._memoiseme(self.USER_1)
|
||||||
|
output_2 = self._memoiseme(self.USER_1)
|
||||||
|
|
||||||
|
self.assertIsNot(output_1, output_2)
|
||||||
|
|
||||||
|
def test_memoisation_is_user_independent(self):
|
||||||
|
with BatchController.batch(self.USER_1):
|
||||||
|
output_1 = self._memoiseme(self.USER_1)
|
||||||
|
with BatchController.batch(self.USER_2):
|
||||||
|
output_2 = self._memoiseme(self.USER_2)
|
||||||
|
output_3 = self._memoiseme(self.USER_1)
|
||||||
|
|
||||||
|
self.assertIsNot(output_1, output_2)
|
||||||
|
self.assertIs(output_1, output_3)
|
||||||
|
|
||||||
|
def test_memoisation_clears_outside_batches(self):
|
||||||
|
with BatchController.batch(self.USER_1):
|
||||||
|
output_1 = self._memoiseme(self.USER_1)
|
||||||
|
|
||||||
|
with BatchController.batch(self.USER_1):
|
||||||
|
output_2 = self._memoiseme(self.USER_1)
|
||||||
|
|
||||||
|
self.assertIsNot(output_1, output_2)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@BatchController.memoise
|
||||||
|
def _memoiseme(self, user):
|
||||||
|
return object()
|
||||||
|
|
||||||
|
def test_batch_end_functionality_is_called(self):
|
||||||
|
class Ender(object):
|
||||||
|
end_count = 0
|
||||||
|
def end_batch(self):
|
||||||
|
self.end_count += 1
|
||||||
|
|
||||||
|
@BatchController.memoise
|
||||||
|
def get_ender(user):
|
||||||
|
return Ender()
|
||||||
|
|
||||||
|
# end_batch should get called once on exiting the batch
|
||||||
|
with BatchController.batch(self.USER_1):
|
||||||
|
ender = get_ender(self.USER_1)
|
||||||
|
self.assertEquals(1, ender.end_count)
|
||||||
|
|
||||||
|
# end_batch should get called once on exiting the batch
|
||||||
|
# no matter how deep the object gets cached
|
||||||
|
with BatchController.batch(self.USER_1):
|
||||||
|
with BatchController.batch(self.USER_1):
|
||||||
|
ender = get_ender(self.USER_1)
|
||||||
|
self.assertEquals(1, ender.end_count)
|
|
@ -12,6 +12,7 @@ from registrasion.models import commerce
|
||||||
from registrasion.models import conditions
|
from registrasion.models import conditions
|
||||||
from registrasion.models import inventory
|
from registrasion.models import inventory
|
||||||
from registrasion.models import people
|
from registrasion.models import people
|
||||||
|
from registrasion.controllers.batch import BatchController
|
||||||
from registrasion.controllers.product import ProductController
|
from registrasion.controllers.product import ProductController
|
||||||
|
|
||||||
from controller_helpers import TestingCartController
|
from controller_helpers import TestingCartController
|
||||||
|
@ -360,3 +361,65 @@ class BasicCartTests(RegistrationCartTestCase):
|
||||||
|
|
||||||
def test_available_products_respects_product_limits(self):
|
def test_available_products_respects_product_limits(self):
|
||||||
self.__available_products_test(self.PROD_4, 6)
|
self.__available_products_test(self.PROD_4, 6)
|
||||||
|
|
||||||
|
def test_cart_controller_for_user_is_memoised(self):
|
||||||
|
# - that for_user is memoised
|
||||||
|
with BatchController.batch(self.USER_1):
|
||||||
|
cart = TestingCartController.for_user(self.USER_1)
|
||||||
|
cart_2 = TestingCartController.for_user(self.USER_1)
|
||||||
|
self.assertIs(cart, cart_2)
|
||||||
|
|
||||||
|
def test_cart_revision_does_not_increment_if_not_modified(self):
|
||||||
|
cart = TestingCartController.for_user(self.USER_1)
|
||||||
|
rev_0 = cart.cart.revision
|
||||||
|
|
||||||
|
with BatchController.batch(self.USER_1):
|
||||||
|
# Memoise the cart
|
||||||
|
same_cart = TestingCartController.for_user(self.USER_1)
|
||||||
|
# Do nothing on exit
|
||||||
|
|
||||||
|
rev_1 = self.reget(cart.cart).revision
|
||||||
|
self.assertEqual(rev_0, rev_1)
|
||||||
|
|
||||||
|
def test_cart_revision_only_increments_at_end_of_batches(self):
|
||||||
|
cart = TestingCartController.for_user(self.USER_1)
|
||||||
|
rev_0 = cart.cart.revision
|
||||||
|
|
||||||
|
with BatchController.batch(self.USER_1):
|
||||||
|
# Memoise the cart
|
||||||
|
same_cart = TestingCartController.for_user(self.USER_1)
|
||||||
|
same_cart.add_to_cart(self.PROD_1, 1)
|
||||||
|
rev_1 = self.reget(same_cart.cart).revision
|
||||||
|
|
||||||
|
rev_2 = self.reget(cart.cart).revision
|
||||||
|
|
||||||
|
self.assertEqual(rev_0, rev_1)
|
||||||
|
self.assertNotEqual(rev_0, rev_2)
|
||||||
|
|
||||||
|
def test_cart_discounts_only_calculated_at_end_of_batches(self):
|
||||||
|
def count_discounts(cart):
|
||||||
|
return cart.cart.discountitem_set.count()
|
||||||
|
|
||||||
|
cart = TestingCartController.for_user(self.USER_1)
|
||||||
|
self.make_discount_ceiling("FLOOZLE")
|
||||||
|
count_0 = count_discounts(cart)
|
||||||
|
|
||||||
|
with BatchController.batch(self.USER_1):
|
||||||
|
# Memoise the cart
|
||||||
|
same_cart = TestingCartController.for_user(self.USER_1)
|
||||||
|
|
||||||
|
with BatchController.batch(self.USER_1):
|
||||||
|
# Memoise the cart
|
||||||
|
same_cart_2 = TestingCartController.for_user(self.USER_1)
|
||||||
|
|
||||||
|
same_cart_2.add_to_cart(self.PROD_1, 1)
|
||||||
|
count_1 = count_discounts(same_cart_2)
|
||||||
|
|
||||||
|
count_2 = count_discounts(same_cart)
|
||||||
|
|
||||||
|
count_3 = count_discounts(cart)
|
||||||
|
|
||||||
|
self.assertEqual(0, count_0)
|
||||||
|
self.assertEqual(0, count_1)
|
||||||
|
self.assertEqual(0, count_2)
|
||||||
|
self.assertEqual(1, count_3)
|
||||||
|
|
|
@ -5,9 +5,10 @@ from registrasion import util
|
||||||
from registrasion.models import commerce
|
from registrasion.models import commerce
|
||||||
from registrasion.models import inventory
|
from registrasion.models import inventory
|
||||||
from registrasion.models import people
|
from registrasion.models import people
|
||||||
from registrasion.controllers.discount import DiscountController
|
from registrasion.controllers.batch import BatchController
|
||||||
from registrasion.controllers.cart import CartController
|
from registrasion.controllers.cart import CartController
|
||||||
from registrasion.controllers.credit_note import CreditNoteController
|
from registrasion.controllers.credit_note import CreditNoteController
|
||||||
|
from registrasion.controllers.discount import DiscountController
|
||||||
from registrasion.controllers.invoice import InvoiceController
|
from registrasion.controllers.invoice import InvoiceController
|
||||||
from registrasion.controllers.product import ProductController
|
from registrasion.controllers.product import ProductController
|
||||||
from registrasion.exceptions import CartValidationError
|
from registrasion.exceptions import CartValidationError
|
||||||
|
@ -170,6 +171,7 @@ def guided_registration(request):
|
||||||
category__in=cats,
|
category__in=cats,
|
||||||
).select_related("category")
|
).select_related("category")
|
||||||
|
|
||||||
|
with BatchController.batch(request.user):
|
||||||
available_products = set(ProductController.available_products(
|
available_products = set(ProductController.available_products(
|
||||||
request.user,
|
request.user,
|
||||||
products=all_products,
|
products=all_products,
|
||||||
|
@ -181,7 +183,6 @@ def guided_registration(request):
|
||||||
attendee.save()
|
attendee.save()
|
||||||
return next_step
|
return next_step
|
||||||
|
|
||||||
with CartController.operations_batch(request.user):
|
|
||||||
for category in cats:
|
for category in cats:
|
||||||
products = [
|
products = [
|
||||||
i for i in available_products
|
i for i in available_products
|
||||||
|
@ -345,6 +346,7 @@ def product_category(request, category_id):
|
||||||
category_id = int(category_id) # Routing is [0-9]+
|
category_id = int(category_id) # Routing is [0-9]+
|
||||||
category = inventory.Category.objects.get(pk=category_id)
|
category = inventory.Category.objects.get(pk=category_id)
|
||||||
|
|
||||||
|
with BatchController.batch(request.user):
|
||||||
products = ProductController.available_products(
|
products = ProductController.available_products(
|
||||||
request.user,
|
request.user,
|
||||||
category=category,
|
category=category,
|
||||||
|
|
Loading…
Reference in a new issue