Adds an operations_batch context manager that allows batches of modifying operations to be nested. Closes #44.
This commit is contained in:
parent
162db24817
commit
587e6e20b2
2 changed files with 75 additions and 18 deletions
|
@ -1,4 +1,5 @@
|
|||
import collections
|
||||
import contextlib
|
||||
import datetime
|
||||
import functools
|
||||
import itertools
|
||||
|
@ -23,11 +24,18 @@ from .product import ProductController
|
|||
|
||||
def _modifies_cart(func):
|
||||
''' Decorator that makes the wrapped function raise ValidationError
|
||||
if we're doing something that could modify the cart. '''
|
||||
if we're doing something that could modify the cart.
|
||||
|
||||
It also wraps the execution of this function in a database transaction,
|
||||
and marks the boundaries of a cart operations batch.
|
||||
'''
|
||||
|
||||
@functools.wraps(func)
|
||||
def inner(self, *a, **k):
|
||||
self._fail_if_cart_is_not_active()
|
||||
with transaction.atomic():
|
||||
with CartController.operations_batch(self.cart.user) as mark:
|
||||
mark.mark = True # Marker that we've modified the cart
|
||||
return func(self, *a, **k)
|
||||
|
||||
return inner
|
||||
|
@ -56,13 +64,57 @@ class CartController(object):
|
|||
)
|
||||
return cls(existing)
|
||||
|
||||
|
||||
# Marks the carts that are currently in batches
|
||||
_BATCH_COUNT = collections.defaultdict(int)
|
||||
_MODIFIED_CARTS = set()
|
||||
|
||||
class _ModificationMarker(object):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
@contextlib.contextmanager
|
||||
def operations_batch(cls, user):
|
||||
''' Marks the boundary for a batch of operations on a user's cart.
|
||||
|
||||
These markers can be nested. Only on exiting the outermost marker will
|
||||
a batch be ended.
|
||||
|
||||
When a batch is ended, discounts are recalculated, and the cart's
|
||||
revision is increased.
|
||||
'''
|
||||
|
||||
ctrl = cls.for_user(user)
|
||||
_id = ctrl.cart.id
|
||||
|
||||
cls._BATCH_COUNT[_id] += 1
|
||||
try:
|
||||
success = False
|
||||
|
||||
marker = cls._ModificationMarker()
|
||||
yield marker
|
||||
|
||||
if hasattr(marker, "mark"):
|
||||
cls._MODIFIED_CARTS.add(_id)
|
||||
|
||||
success = True
|
||||
finally:
|
||||
|
||||
cls._BATCH_COUNT[_id] -= 1
|
||||
|
||||
# Only end on the outermost batch marker, and only if
|
||||
# it excited cleanly, and a modification occurred
|
||||
modified = _id in cls._MODIFIED_CARTS
|
||||
if modified and cls._BATCH_COUNT[_id] == 0 and success:
|
||||
ctrl._end_batch()
|
||||
cls._MODIFIED_CARTS.remove(_id)
|
||||
|
||||
def _fail_if_cart_is_not_active(self):
|
||||
self.cart.refresh_from_db()
|
||||
if self.cart.status != commerce.Cart.STATUS_ACTIVE:
|
||||
raise ValidationError("You can only amend active carts.")
|
||||
|
||||
@_modifies_cart
|
||||
def extend_reservation(self):
|
||||
def _autoextend_reservation(self):
|
||||
''' Updates the cart's time last updated value, which is used to
|
||||
determine whether the cart has reserved the items and discounts it
|
||||
holds. '''
|
||||
|
@ -84,21 +136,26 @@ class CartController(object):
|
|||
self.cart.time_last_updated = timezone.now()
|
||||
self.cart.reservation_duration = max(reservations)
|
||||
|
||||
@_modifies_cart
|
||||
def end_batch(self):
|
||||
def _end_batch(self):
|
||||
''' Performs operations that occur occur at the end of a batch of
|
||||
product changes/voucher applications etc.
|
||||
THIS SHOULD BE PRIVATE
|
||||
|
||||
You need to call this after you've finished modifying the user's cart.
|
||||
This is normally done by wrapping a block of code using
|
||||
``operations_batch``.
|
||||
|
||||
'''
|
||||
|
||||
self.recalculate_discounts()
|
||||
|
||||
self.extend_reservation()
|
||||
self.cart.refresh_from_db()
|
||||
|
||||
self._recalculate_discounts()
|
||||
|
||||
self._autoextend_reservation()
|
||||
self.cart.revision += 1
|
||||
self.cart.save()
|
||||
|
||||
@_modifies_cart
|
||||
@transaction.atomic
|
||||
def set_quantities(self, product_quantities):
|
||||
''' Sets the quantities on each of the products on each of the
|
||||
products specified. Raises an exception (ValidationError) if a limit
|
||||
|
@ -140,8 +197,6 @@ class CartController(object):
|
|||
|
||||
items_in_cart.filter(quantity=0).delete()
|
||||
|
||||
self.end_batch()
|
||||
|
||||
def _test_limits(self, product_quantities):
|
||||
''' Tests that the quantity changes we intend to make do not violate
|
||||
the limits and flag conditions imposed on the products. '''
|
||||
|
@ -213,7 +268,6 @@ class CartController(object):
|
|||
|
||||
# If successful...
|
||||
self.cart.vouchers.add(voucher)
|
||||
self.end_batch()
|
||||
|
||||
def _test_voucher(self, voucher):
|
||||
''' Tests whether this voucher is allowed to be applied to this cart.
|
||||
|
@ -331,7 +385,6 @@ class CartController(object):
|
|||
raise ValidationError(errors)
|
||||
|
||||
@_modifies_cart
|
||||
@transaction.atomic
|
||||
def fix_simple_errors(self):
|
||||
''' This attempts to fix the easy errors raised by ValidationError.
|
||||
This includes removing items from the cart that are no longer
|
||||
|
@ -363,11 +416,9 @@ class CartController(object):
|
|||
|
||||
self.set_quantities(zeros)
|
||||
|
||||
@_modifies_cart
|
||||
@transaction.atomic
|
||||
def recalculate_discounts(self):
|
||||
''' Calculates all of the discounts available for this product.
|
||||
'''
|
||||
def _recalculate_discounts(self):
|
||||
''' Calculates all of the discounts available for this product.'''
|
||||
|
||||
# Delete the existing entries.
|
||||
commerce.DiscountItem.objects.filter(cart=self.cart).delete()
|
||||
|
|
|
@ -29,6 +29,7 @@ class InvoiceController(ForId, object):
|
|||
If such an invoice does not exist, the cart is validated, and if valid,
|
||||
an invoice is generated.'''
|
||||
|
||||
cart.refresh_from_db()
|
||||
try:
|
||||
invoice = commerce.Invoice.objects.exclude(
|
||||
status=commerce.Invoice.STATUS_VOID,
|
||||
|
@ -74,6 +75,8 @@ class InvoiceController(ForId, object):
|
|||
def _generate(cls, cart):
|
||||
''' Generates an invoice for the given cart. '''
|
||||
|
||||
cart.refresh_from_db()
|
||||
|
||||
issued = timezone.now()
|
||||
reservation_limit = cart.reservation_duration + cart.time_last_updated
|
||||
# Never generate a due time that is before the issue time
|
||||
|
@ -251,6 +254,9 @@ class InvoiceController(ForId, object):
|
|||
def _invoice_matches_cart(self):
|
||||
''' Returns true if there is no cart, or if the revision of this
|
||||
invoice matches the current revision of the cart. '''
|
||||
|
||||
self._refresh()
|
||||
|
||||
cart = self.invoice.cart
|
||||
if not cart:
|
||||
return True
|
||||
|
|
Loading…
Reference in a new issue