diff --git a/accounting/decorators.py b/accounting/decorators.py index c7f1595..da4cb55 100644 --- a/accounting/decorators.py +++ b/accounting/decorators.py @@ -4,7 +4,7 @@ from functools import wraps -from flask import jsonify +from flask import jsonify, request from accounting.exceptions import AccountingException @@ -22,3 +22,65 @@ def jsonify_exceptions(func): return jsonify(error=exc) return wrapper + + +def cors(origin_callback=None): + ''' + Flask endpoint decorator. + + Example: + + .. code-block:: python + + @app.route('/cors-endpoint', methods=['GET', 'OPTIONS']) + @cors() + def cors_endpoint(): + return jsonify(message='This is accessible via a cross-origin XHR') + + # Or if you want to control the domains this resource can be requested + # from via CORS: + domains = ['http://wandborg.se', 'http://sfconservancy.org'] + + def restrict_domains(origin): + return ' '.join(domains) + + @app.route('/restricted-cors-endpoint') + @cors(restrict_domains) + def restricted_cors_endpoint(): + return jsonify( + message='This is accessible from %s' % ', '.join(domains)) + + :param function origin_callback: A callback that takes one str() argument + containing the ``Origin`` HTTP header from the :data:`request` object. + This can be used to filter out which domains the resource can be + requested via CORS from. + ''' + if origin_callback is None: + origin_callback = allow_all_origins + + def decorator(func): + @wraps(func) + def wrapper(*args, **kw): + response = func(*args, **kw) + cors_headers = { + 'Access-Control-Allow-Origin': + origin_callback(request.headers.get('Origin')) or '*', + 'Access-Control-Allow-Credentials': 'true', + 'Access-Control-Max-Age': 3600, + 'Access-Control-Allow-Methods': 'POST, GET, DELETE', + 'Access-Control-Allow-Headers': + 'Accept, Content-Type, Connection, Cookie' + } + + for key, val in cors_headers.items(): + response.headers[key] = val + + return response + + return wrapper + + return decorator + + +def allow_all_origins(origin): + return origin diff --git a/accounting/web.py b/accounting/web.py index 53d77ce..ea6ae51 100644 --- a/accounting/web.py +++ b/accounting/web.py @@ -10,16 +10,17 @@ import sys import logging import argparse -from flask import Flask, jsonify, request +from flask import Flask, jsonify, request, render_template from flask.ext.script import Manager from flask.ext.migrate import Migrate, MigrateCommand +from accounting.models import Transaction from accounting.storage import Storage from accounting.storage.ledgercli import Ledger from accounting.storage.sql import SQLStorage from accounting.transport import AccountingEncoder, AccountingDecoder from accounting.exceptions import AccountingException -from accounting.decorators import jsonify_exceptions +from accounting.decorators import jsonify_exceptions, cors app = Flask('accounting') @@ -56,16 +57,41 @@ def index(): ''' Hello World! ''' return 'Hello World!' +@app.route('/client') +def client(): + return render_template('client.html') + + +@app.route('/transaction', methods=['OPTIONS']) +@cors() +@jsonify_exceptions +def transaction_options(): + return jsonify(status='OPTIONS') + + +@app.route('/transaction/', methods=['OPTIONS']) +@cors() +@jsonify_exceptions +def transaction_by_id_options(transaction_id=None): + return jsonify(status='OPTIONS') + @app.route('/transaction', methods=['GET']) -def transaction_get(): +@app.route('/transaction/', methods=['GET']) +@cors() +@jsonify_exceptions +def transaction_get(transaction_id=None): ''' Returns the JSON-serialized output of :meth:`accounting.Ledger.reg` ''' - return jsonify(transactions=storage.get_transactions()) + if transaction_id is None: + return jsonify(transactions=storage.get_transactions()) + + return jsonify(transaction=storage.get_transaction(transaction_id)) @app.route('/transaction/', methods=['POST']) +@cors() @jsonify_exceptions def transaction_update(transaction_id=None): if transaction_id is None: @@ -85,6 +111,7 @@ def transaction_update(transaction_id=None): @app.route('/transaction/', methods=['DELETE']) +@cors() @jsonify_exceptions def transaction_delete(transaction_id=None): if transaction_id is None: @@ -96,6 +123,7 @@ def transaction_delete(transaction_id=None): @app.route('/transaction', methods=['POST']) +@cors() @jsonify_exceptions def transaction_post(): ''' @@ -145,7 +173,10 @@ def transaction_post(): Income:Foo:Donation $ -100 Assets:Checking $ 100 ''' - transactions = request.json.get('transactions') + if not isinstance(request.json, Transaction): + transactions = request.json.get('transactions') + else: + transactions = [request.json] if not transactions: raise AccountingException('No transaction data provided')