diff --git a/registrasion/controllers/credit_note.py b/registrasion/controllers/credit_note.py index e1f0ed2b..182c10e9 100644 --- a/registrasion/controllers/credit_note.py +++ b/registrasion/controllers/credit_note.py @@ -2,8 +2,12 @@ from django.db import transaction from registrasion.models import commerce +from for_id import ForId -class CreditNoteController(object): + +class CreditNoteController(ForId, object): + + __MODEL__ = commerce.CreditNote def __init__(self, credit_note): self.credit_note = credit_note diff --git a/registrasion/controllers/for_id.py b/registrasion/controllers/for_id.py new file mode 100644 index 00000000..3748b151 --- /dev/null +++ b/registrasion/controllers/for_id.py @@ -0,0 +1,24 @@ +from django.core.exceptions import ObjectDoesNotExist +from django.http import Http404 + + +class ForId(object): + ''' Mixin class that gives you new classmethods: for_id for_id_or_404. + These let you retrieve an instance of the class by specifying the model ID. + + Your subclass must define __MODEL__ as a class attribute. This will be the + model class that we wrap. There must also be a constructor that takes a + single argument: the instance of the model that we are controlling. ''' + + @classmethod + def for_id(cls, id_): + id_ = int(id_) + obj = cls.__MODEL__.objects.get(pk=id_) + return cls(obj) + + @classmethod + def for_id_or_404(cls, id_): + try: + return cls.for_id(id_) + except ObjectDoesNotExist: + return Http404 diff --git a/registrasion/controllers/invoice.py b/registrasion/controllers/invoice.py index fff97e8d..d2b6bf3f 100644 --- a/registrasion/controllers/invoice.py +++ b/registrasion/controllers/invoice.py @@ -11,9 +11,11 @@ from registrasion.models import people from cart import CartController from credit_note import CreditNoteController +from for_id import ForId +class InvoiceController(ForId, object): -class InvoiceController(object): + __MODEL__ = commerce.Invoice def __init__(self, invoice): self.invoice = invoice diff --git a/registrasion/tests/test_invoice.py b/registrasion/tests/test_invoice.py index 3a655bb1..8e1c7ac0 100644 --- a/registrasion/tests/test_invoice.py +++ b/registrasion/tests/test_invoice.py @@ -53,6 +53,20 @@ class InvoiceTestCase(RegistrationCartTestCase): self.PROD_1.price + self.PROD_2.price, invoice_2.invoice.value) + def test_invoice_controller_for_id_works(self): + current_cart = TestingCartController.for_user(self.USER_1) + current_cart.add_to_cart(self.PROD_1, 1) + + invoice = TestingInvoiceController.for_cart(current_cart.cart) + + id_ = invoice.invoice.id + + invoice1 = TestingInvoiceController.for_id(id_) + invoice2 = TestingInvoiceController.for_id(str(id_)) + + self.assertEqual(invoice.invoice, invoice1.invoice) + self.assertEqual(invoice.invoice, invoice2.invoice) + def test_create_invoice_fails_if_cart_invalid(self): self.make_ceiling("Limit ceiling", limit=1) self.set_time(datetime.datetime(2015, 01, 01, tzinfo=UTC))