Added CORS Flask endpoint decorator

This commit is contained in:
Joar Wandborg 2013-12-20 14:51:32 +01:00
parent 3378805b80
commit 2eecd641dd
2 changed files with 99 additions and 6 deletions

View file

@ -4,7 +4,7 @@
from functools import wraps from functools import wraps
from flask import jsonify from flask import jsonify, request
from accounting.exceptions import AccountingException from accounting.exceptions import AccountingException
@ -22,3 +22,65 @@ def jsonify_exceptions(func):
return jsonify(error=exc) return jsonify(error=exc)
return wrapper return wrapper
def cors(origin_callback=None):
'''
Flask endpoint decorator.
Example:
.. code-block:: python
@app.route('/cors-endpoint', methods=['GET', 'OPTIONS'])
@cors()
def cors_endpoint():
return jsonify(message='This is accessible via a cross-origin XHR')
# Or if you want to control the domains this resource can be requested
# from via CORS:
domains = ['http://wandborg.se', 'http://sfconservancy.org']
def restrict_domains(origin):
return ' '.join(domains)
@app.route('/restricted-cors-endpoint')
@cors(restrict_domains)
def restricted_cors_endpoint():
return jsonify(
message='This is accessible from %s' % ', '.join(domains))
:param function origin_callback: A callback that takes one str() argument
containing the ``Origin`` HTTP header from the :data:`request` object.
This can be used to filter out which domains the resource can be
requested via CORS from.
'''
if origin_callback is None:
origin_callback = allow_all_origins
def decorator(func):
@wraps(func)
def wrapper(*args, **kw):
response = func(*args, **kw)
cors_headers = {
'Access-Control-Allow-Origin':
origin_callback(request.headers.get('Origin')) or '*',
'Access-Control-Allow-Credentials': 'true',
'Access-Control-Max-Age': 3600,
'Access-Control-Allow-Methods': 'POST, GET, DELETE',
'Access-Control-Allow-Headers':
'Accept, Content-Type, Connection, Cookie'
}
for key, val in cors_headers.items():
response.headers[key] = val
return response
return wrapper
return decorator
def allow_all_origins(origin):
return origin

View file

@ -10,16 +10,17 @@ import sys
import logging import logging
import argparse import argparse
from flask import Flask, jsonify, request from flask import Flask, jsonify, request, render_template
from flask.ext.script import Manager from flask.ext.script import Manager
from flask.ext.migrate import Migrate, MigrateCommand from flask.ext.migrate import Migrate, MigrateCommand
from accounting.models import Transaction
from accounting.storage import Storage from accounting.storage import Storage
from accounting.storage.ledgercli import Ledger from accounting.storage.ledgercli import Ledger
from accounting.storage.sql import SQLStorage from accounting.storage.sql import SQLStorage
from accounting.transport import AccountingEncoder, AccountingDecoder from accounting.transport import AccountingEncoder, AccountingDecoder
from accounting.exceptions import AccountingException from accounting.exceptions import AccountingException
from accounting.decorators import jsonify_exceptions from accounting.decorators import jsonify_exceptions, cors
app = Flask('accounting') app = Flask('accounting')
@ -56,16 +57,41 @@ def index():
''' Hello World! ''' ''' Hello World! '''
return 'Hello World!' return 'Hello World!'
@app.route('/client')
def client():
return render_template('client.html')
@app.route('/transaction', methods=['OPTIONS'])
@cors()
@jsonify_exceptions
def transaction_options():
return jsonify(status='OPTIONS')
@app.route('/transaction/<string:transaction_id>', methods=['OPTIONS'])
@cors()
@jsonify_exceptions
def transaction_by_id_options(transaction_id=None):
return jsonify(status='OPTIONS')
@app.route('/transaction', methods=['GET']) @app.route('/transaction', methods=['GET'])
def transaction_get(): @app.route('/transaction/<string:transaction_id>', methods=['GET'])
@cors()
@jsonify_exceptions
def transaction_get(transaction_id=None):
''' '''
Returns the JSON-serialized output of :meth:`accounting.Ledger.reg` Returns the JSON-serialized output of :meth:`accounting.Ledger.reg`
''' '''
return jsonify(transactions=storage.get_transactions()) if transaction_id is None:
return jsonify(transactions=storage.get_transactions())
return jsonify(transaction=storage.get_transaction(transaction_id))
@app.route('/transaction/<string:transaction_id>', methods=['POST']) @app.route('/transaction/<string:transaction_id>', methods=['POST'])
@cors()
@jsonify_exceptions @jsonify_exceptions
def transaction_update(transaction_id=None): def transaction_update(transaction_id=None):
if transaction_id is None: if transaction_id is None:
@ -85,6 +111,7 @@ def transaction_update(transaction_id=None):
@app.route('/transaction/<string:transaction_id>', methods=['DELETE']) @app.route('/transaction/<string:transaction_id>', methods=['DELETE'])
@cors()
@jsonify_exceptions @jsonify_exceptions
def transaction_delete(transaction_id=None): def transaction_delete(transaction_id=None):
if transaction_id is None: if transaction_id is None:
@ -96,6 +123,7 @@ def transaction_delete(transaction_id=None):
@app.route('/transaction', methods=['POST']) @app.route('/transaction', methods=['POST'])
@cors()
@jsonify_exceptions @jsonify_exceptions
def transaction_post(): def transaction_post():
''' '''
@ -145,7 +173,10 @@ def transaction_post():
Income:Foo:Donation $ -100 Income:Foo:Donation $ -100
Assets:Checking $ 100 Assets:Checking $ 100
''' '''
transactions = request.json.get('transactions') if not isinstance(request.json, Transaction):
transactions = request.json.get('transactions')
else:
transactions = [request.json]
if not transactions: if not transactions:
raise AccountingException('No transaction data provided') raise AccountingException('No transaction data provided')