Adds an operations_batch context manager that allows batches of modifying operations to be nested. Closes #44.

This commit is contained in:
Christopher Neugebauer 2016-04-28 14:01:36 +10:00
parent 162db24817
commit 587e6e20b2
2 changed files with 75 additions and 18 deletions

View file

@ -1,4 +1,5 @@
import collections import collections
import contextlib
import datetime import datetime
import functools import functools
import itertools import itertools
@ -23,11 +24,18 @@ from .product import ProductController
def _modifies_cart(func): def _modifies_cart(func):
''' Decorator that makes the wrapped function raise ValidationError ''' Decorator that makes the wrapped function raise ValidationError
if we're doing something that could modify the cart. ''' if we're doing something that could modify the cart.
It also wraps the execution of this function in a database transaction,
and marks the boundaries of a cart operations batch.
'''
@functools.wraps(func) @functools.wraps(func)
def inner(self, *a, **k): def inner(self, *a, **k):
self._fail_if_cart_is_not_active() self._fail_if_cart_is_not_active()
with transaction.atomic():
with CartController.operations_batch(self.cart.user) as mark:
mark.mark = True # Marker that we've modified the cart
return func(self, *a, **k) return func(self, *a, **k)
return inner return inner
@ -56,13 +64,57 @@ class CartController(object):
) )
return cls(existing) return cls(existing)
# Marks the carts that are currently in batches
_BATCH_COUNT = collections.defaultdict(int)
_MODIFIED_CARTS = set()
class _ModificationMarker(object):
pass
@classmethod
@contextlib.contextmanager
def operations_batch(cls, user):
''' Marks the boundary for a batch of operations on a user's cart.
These markers can be nested. Only on exiting the outermost marker will
a batch be ended.
When a batch is ended, discounts are recalculated, and the cart's
revision is increased.
'''
ctrl = cls.for_user(user)
_id = ctrl.cart.id
cls._BATCH_COUNT[_id] += 1
try:
success = False
marker = cls._ModificationMarker()
yield marker
if hasattr(marker, "mark"):
cls._MODIFIED_CARTS.add(_id)
success = True
finally:
cls._BATCH_COUNT[_id] -= 1
# Only end on the outermost batch marker, and only if
# it excited cleanly, and a modification occurred
modified = _id in cls._MODIFIED_CARTS
if modified and cls._BATCH_COUNT[_id] == 0 and success:
ctrl._end_batch()
cls._MODIFIED_CARTS.remove(_id)
def _fail_if_cart_is_not_active(self): def _fail_if_cart_is_not_active(self):
self.cart.refresh_from_db() self.cart.refresh_from_db()
if self.cart.status != commerce.Cart.STATUS_ACTIVE: if self.cart.status != commerce.Cart.STATUS_ACTIVE:
raise ValidationError("You can only amend active carts.") raise ValidationError("You can only amend active carts.")
@_modifies_cart def _autoextend_reservation(self):
def extend_reservation(self):
''' Updates the cart's time last updated value, which is used to ''' Updates the cart's time last updated value, which is used to
determine whether the cart has reserved the items and discounts it determine whether the cart has reserved the items and discounts it
holds. ''' holds. '''
@ -84,21 +136,26 @@ class CartController(object):
self.cart.time_last_updated = timezone.now() self.cart.time_last_updated = timezone.now()
self.cart.reservation_duration = max(reservations) self.cart.reservation_duration = max(reservations)
@_modifies_cart def _end_batch(self):
def end_batch(self):
''' Performs operations that occur occur at the end of a batch of ''' Performs operations that occur occur at the end of a batch of
product changes/voucher applications etc. product changes/voucher applications etc.
THIS SHOULD BE PRIVATE
You need to call this after you've finished modifying the user's cart.
This is normally done by wrapping a block of code using
``operations_batch``.
''' '''
self.recalculate_discounts()
self.extend_reservation() self.cart.refresh_from_db()
self._recalculate_discounts()
self._autoextend_reservation()
self.cart.revision += 1 self.cart.revision += 1
self.cart.save() self.cart.save()
@_modifies_cart @_modifies_cart
@transaction.atomic
def set_quantities(self, product_quantities): def set_quantities(self, product_quantities):
''' Sets the quantities on each of the products on each of the ''' Sets the quantities on each of the products on each of the
products specified. Raises an exception (ValidationError) if a limit products specified. Raises an exception (ValidationError) if a limit
@ -140,8 +197,6 @@ class CartController(object):
items_in_cart.filter(quantity=0).delete() items_in_cart.filter(quantity=0).delete()
self.end_batch()
def _test_limits(self, product_quantities): def _test_limits(self, product_quantities):
''' Tests that the quantity changes we intend to make do not violate ''' Tests that the quantity changes we intend to make do not violate
the limits and flag conditions imposed on the products. ''' the limits and flag conditions imposed on the products. '''
@ -213,7 +268,6 @@ class CartController(object):
# If successful... # If successful...
self.cart.vouchers.add(voucher) self.cart.vouchers.add(voucher)
self.end_batch()
def _test_voucher(self, voucher): def _test_voucher(self, voucher):
''' Tests whether this voucher is allowed to be applied to this cart. ''' Tests whether this voucher is allowed to be applied to this cart.
@ -331,7 +385,6 @@ class CartController(object):
raise ValidationError(errors) raise ValidationError(errors)
@_modifies_cart @_modifies_cart
@transaction.atomic
def fix_simple_errors(self): def fix_simple_errors(self):
''' This attempts to fix the easy errors raised by ValidationError. ''' This attempts to fix the easy errors raised by ValidationError.
This includes removing items from the cart that are no longer This includes removing items from the cart that are no longer
@ -363,11 +416,9 @@ class CartController(object):
self.set_quantities(zeros) self.set_quantities(zeros)
@_modifies_cart
@transaction.atomic @transaction.atomic
def recalculate_discounts(self): def _recalculate_discounts(self):
''' Calculates all of the discounts available for this product. ''' Calculates all of the discounts available for this product.'''
'''
# Delete the existing entries. # Delete the existing entries.
commerce.DiscountItem.objects.filter(cart=self.cart).delete() commerce.DiscountItem.objects.filter(cart=self.cart).delete()

View file

@ -29,6 +29,7 @@ class InvoiceController(ForId, object):
If such an invoice does not exist, the cart is validated, and if valid, If such an invoice does not exist, the cart is validated, and if valid,
an invoice is generated.''' an invoice is generated.'''
cart.refresh_from_db()
try: try:
invoice = commerce.Invoice.objects.exclude( invoice = commerce.Invoice.objects.exclude(
status=commerce.Invoice.STATUS_VOID, status=commerce.Invoice.STATUS_VOID,
@ -74,6 +75,8 @@ class InvoiceController(ForId, object):
def _generate(cls, cart): def _generate(cls, cart):
''' Generates an invoice for the given cart. ''' ''' Generates an invoice for the given cart. '''
cart.refresh_from_db()
issued = timezone.now() issued = timezone.now()
reservation_limit = cart.reservation_duration + cart.time_last_updated reservation_limit = cart.reservation_duration + cart.time_last_updated
# Never generate a due time that is before the issue time # Never generate a due time that is before the issue time
@ -251,6 +254,9 @@ class InvoiceController(ForId, object):
def _invoice_matches_cart(self): def _invoice_matches_cart(self):
''' Returns true if there is no cart, or if the revision of this ''' Returns true if there is no cart, or if the revision of this
invoice matches the current revision of the cart. ''' invoice matches the current revision of the cart. '''
self._refresh()
cart = self.invoice.cart cart = self.invoice.cart
if not cart: if not cart:
return True return True