How do I use the Flask-Cache @cache.cached() decorator with Flask-Restful? For example, I have a class Foo inherited from Resource, and Foo has get, post, put, and delete methods.
How can I invalidate cached results after a POST?
@api.resource('/whatever')
class Foo(Resource):
    @cache.cached(timeout=10)
    def get(self):
        return expensive_db_operation()

    def post(self):
        update_db_here()
        # How do I invalidate the value cached in get()?
        return something_useful()
As the Flask-Cache implementation doesn't give you access to the underlying cache object, you'll have to explicitly instantiate a Redis client and use its keys method (which lists all cache keys).
The cache_key function is used to override the default key generation in your @cache.cached decorator.
The clear_cache method will clear only the portion of the cache corresponding to the current resource.
This solution was tested only with Redis, and the implementation will probably differ a little when using a different cache engine.
import logging
from urllib.parse import urlencode  # Python 3 location of urlencode

from flask import request
from flask_restful import Resource
from redis import Redis

from app import api, cache  # the Flask-Restful Api and Flask-Cache objects (assumed to live in app.py)
from config import CACHE_REDIS_HOST, CACHE_REDIS_PORT  # the Flask-Cache config

log = logging.getLogger(__name__)

# decode_responses=True makes keys() return str rather than bytes on Python 3,
# so the startswith() check in clear_cache compares str with str
redis_client = Redis(CACHE_REDIS_HOST, CACHE_REDIS_PORT, decode_responses=True)


def cache_key():
    args = request.args
    key = request.path + '?' + urlencode([
        (k, v) for k in sorted(args) for v in sorted(args.getlist(k))
    ])
    return key


@api.resource('/whatever')
class Foo(Resource):
    @cache.cached(timeout=10, key_prefix=cache_key)
    def get(self):
        return expensive_db_operation()

    def post(self):
        update_db_here()
        self.clear_cache()
        return something_useful()

    def clear_cache(self):
        # Note: we have to use the Redis client to delete keys by prefix,
        # so we can't use the 'cache' Flask extension for this one.
        key_prefix = request.path
        keys = [key for key in redis_client.keys() if key.startswith(key_prefix)]
        nkeys = len(keys)
        for key in keys:
            redis_client.delete(key)
        if nkeys > 0:
            log.info("Cleared %s cache keys" % nkeys)
            log.info(keys)
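The key scheme and the prefix-based invalidation above can be tried without Flask or Redis. This is a minimal sketch, assuming the query string is given as a plain mapping of parameter name to a list of values (standing in for `request.args`) and a dict standing in for the Redis keyspace; `make_key` and `clear_prefix` are hypothetical names. Matching on `path + '?'` rather than the bare path is a small design choice that stops clearing `/whatever` from also wiping `/whatever-else`.

```python
from typing import Dict, List, Mapping
from urllib.parse import urlencode


def make_key(path: str, args: Mapping[str, List[str]]) -> str:
    """Build a cache key that does not depend on query-parameter order."""
    pairs = [(k, v) for k in sorted(args) for v in sorted(args[k])]
    return path + "?" + urlencode(pairs)


def clear_prefix(store: Dict[str, object], path: str) -> int:
    """Delete every cached entry for `path`; return how many were removed."""
    prefix = path + "?"  # the '?' keeps '/whatever' from matching '/whatever-else'
    doomed = [key for key in store if key.startswith(prefix)]
    for key in doomed:
        del store[key]
    return len(doomed)


# a plain dict plays the role of the Redis keyspace
store: Dict[str, object] = {}
store[make_key("/whatever", {"b": ["2"], "a": ["1"]})] = "page 1"
store[make_key("/whatever", {})] = "page 0"
store[make_key("/whatever-else", {})] = "other"

print(make_key("/whatever", {"b": ["2"], "a": ["1"]}))  # /whatever?a=1&b=2
print(clear_prefix(store, "/whatever"))  # 2
print(list(store))  # ['/whatever-else?']
```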
Yes, you can use it like that.
You may also want to read "flask-cache memoize URL query string parameters as well".
You can invalidate the cache using the cache.clear() method.
For more details, see https://pythonhosted.org/Flask-Cache/#flask.ext.cache.Cache.clear and the Clearing Cache section in https://pythonhosted.org/Flask-Cache/
# create a decorator
import functools

from cachelib import SimpleCache  # werkzeug.contrib.cache was removed in Werkzeug 1.0
from flask import request

CACHE_TIMEOUT = 300

cache = SimpleCache()


class cached(object):
    def __init__(self, timeout=None):
        self.timeout = timeout or CACHE_TIMEOUT

    def __call__(self, f):
        @functools.wraps(f)  # keep the view's name and docstring
        def decorator(*args, **kwargs):
            response = cache.get(request.path)
            if response is None:
                response = f(*args, **kwargs)
                cache.set(request.path, response, self.timeout)
            return response
        return decorator


# add this decorator to your views like below
@app.route('/buildingTotal', endpoint='buildingTotal')
@cached()
def eventAlert():
    return 'something'


@app.route('/buildingTenants', endpoint='buildingTenants')
@cached()
def buildingTenants():
    return 'something'
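The decorator above keys only on `request.path`, which excludes the query string, so `/buildingTotal?id=1` and `/buildingTotal?id=2` share one entry, and nothing in it removes an entry when the data changes. A minimal stdlib-only sketch of the same idea, assuming the key is supplied by a callable (in Flask it would combine `request.path` and `request.args`; here a plain function stands in) and an injectable clock so expiry can be tested. `TTLCache` and `key_func` are hypothetical names, not part of cachelib or Werkzeug.

```python
import functools
import time
from typing import Any, Callable, Dict, Optional, Tuple


class TTLCache:
    """Tiny in-process cache: key -> (expires_at, value)."""

    def __init__(self, default_timeout: float = 300.0,
                 clock: Callable[[], float] = time.monotonic):
        self.default_timeout = default_timeout
        self.clock = clock
        self._data: Dict[str, Tuple[float, Any]] = {}

    def get(self, key: str) -> Optional[Any]:
        entry = self._data.get(key)
        if entry is None:
            return None
        expires_at, value = entry
        if expires_at <= self.clock():
            del self._data[key]  # drop expired entries lazily, on read
            return None
        return value

    def set(self, key: str, value: Any, timeout: Optional[float] = None) -> None:
        ttl = self.default_timeout if timeout is None else timeout
        self._data[key] = (self.clock() + ttl, value)

    def delete(self, key: str) -> None:
        # what a POST handler would call to invalidate the GET's entry
        self._data.pop(key, None)


def cached(cache: TTLCache, key_func: Callable[[], str],
           timeout: Optional[float] = None):
    """Cache a view's return value under key_func(); a None result is never cached."""
    def decorator(f):
        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            key = key_func()
            response = cache.get(key)
            if response is None:
                response = f(*args, **kwargs)
                cache.set(key, response, timeout)
            return response
        return wrapper
    return decorator
```

A POST handler can then call `cache.delete(...)` with the same key the GET would build.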
The answer from @JahMyst didn't work for me.
Flask-Cache doesn't work with the Flask-Restful framework. @cache.cached and @cache.memoize can't handle mutable objects, per their documentation:
Using mutable objects (classes, etc) as part of the cache key can become tricky. It is suggested to not pass in an object instance into a memoized function. However, the memoize does perform a repr() on the passed in arguments so that if the object has a __repr__ function that returns a uniquely identifying string for that object, that will be used as part of the cache key.
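The quoted point about `repr()` can be illustrated with a stripped-down memoizer. This is a minimal sketch, not Flask-Cache's actual implementation: `memoize_by_repr` and `Query` are hypothetical names, and the key is simply the `repr()` of the positional and keyword arguments. Because the default `object.__repr__` includes the memory address, instances without a custom `__repr__` would generally not share an entry.

```python
import functools
from typing import Any, Callable, Dict


def memoize_by_repr(func: Callable) -> Callable:
    """Cache results under repr() of the arguments (illustration only)."""
    results: Dict[str, Any] = {}

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        key = repr((args, sorted(kwargs.items())))
        if key not in results:
            results[key] = func(*args, **kwargs)
        return results[key]
    return wrapper


class Query:
    def __init__(self, table: str, limit: int):
        self.table = table
        self.limit = limit

    def __repr__(self) -> str:
        # identifies the query by its contents, not by its memory address
        return "Query(table={!r}, limit={})".format(self.table, self.limit)


@memoize_by_repr
def run(query: Query) -> int:
    print("running", query)
    return query.limit * 2


run(Query("users", 5))  # prints: running Query(table='users', limit=5)
run(Query("users", 5))  # separate instance, same repr: served from the cache
```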
I had to come up with my own implementation. I'm leaving this code snippet here in case someone else gets stuck with the same issue.
The cache_key function converts the user request into a hash.
The cache_res_pickled function is used to pickle or unpickle the data.
|-flask-app
|-app.py
|-resource
|--some_resource.py
import datetime
import hashlib
import json
import logging
import pickle
import time
import urllib.parse

from flask import Response, abort, request
from flask_restful import Resource
from redis import Redis

redis_client = Redis("127.0.0.1", 6379)
exp_setting_s = 1500


def json_serial(obj):
    """
    JSON serializer for objects not serializable by default json code
    Args:
        obj: JSON serialized object for dates
    Returns:
        serialized JSON data
    """
    if isinstance(obj, datetime.datetime):
        return obj.__str__()
    raise TypeError("Type %s is not JSON serializable" % type(obj))


def cache_key():
    """
    Returns: Hashed string of request made by the user.
    """
    args = request.args
    key = (
        request.path
        + "?"
        + urllib.parse.urlencode(
            [(k, v) for k in sorted(args) for v in sorted(args.getlist(k))]
        )
    )
    key_hashed = hashlib.sha256(key.encode())
    return key_hashed.hexdigest()


def cache_res_pickled(data, encode):
    """
    Args:
        data (dict): Data in dict format
        encode (Boolean): Encode (true) or decode (false) the data
    Returns: Result after pickling
    """
    if encode:
        return pickle.dumps(data)
    else:
        data = pickle.loads(data)
        return data


class SomeResource(Resource):
    @auth.login_required  # `auth` is assumed to be defined elsewhere (e.g. a Flask-HTTPAuth object)
    def get(self):
        # Get the key for request in hashed format SHA256
        key = cache_key()
        result = redis_client.get(key)

        def generate():
            """
            A lagging generator to stream JSON so we don't have to hold everything in memory
            This is a little tricky, as we need to omit the last comma to make valid JSON,
            thus we use a lagging generator, similar to http://stackoverflow.com/questions/1630320/
            """
            releases = res.__iter__()
            try:
                prev_release = next(releases)  # get first result
                # We have some releases. First, yield the opening json
                yield '{"data": ['
                # Iterate over the releases
                for release in releases:
                    yield json.dumps(prev_release, default=json_serial) + ", "
                    prev_release = release
                logging.info(f"For {key} # records returned = {len(res)}")
                # Now yield the last iteration without comma but with the closing brackets
                yield json.dumps(prev_release, default=json_serial) + "]}"
            except StopIteration:
                # StopIteration here means the length was zero, so yield a valid releases doc and stop
                logging.info(f"For {key} # records returned = {len(res)}")
                yield '{"data": []}'

        if result is None:
            # Secure a key on Redis server.
            redis_client.set(key, cache_res_pickled({}, True), ex=exp_setting_s)
            try:
                # Do the querying to the DB or math here to get res. It should be in dict format as shown below
                res = {"A": 1, "B": 2, "C": 2}
                # Update the key on Redis server with the latest data
                redis_client.set(key, cache_res_pickled(res, True), ex=exp_setting_s)
                return Response(generate(), content_type="application/json")
            except Exception as e:
                logging.exception(e)
                # drop the empty placeholder so waiting requests don't spin until it expires
                redis_client.delete(key)
                abort(500, description="Failed to fetch the resource. error - {}".format(e))
        else:
            res = cache_res_pickled(result, False)
            if res:
                logging.info(
                    f"The data already exists!😊 loading the data from Redis cache for Key - {key} "
                )
                return Response(generate(), content_type="application/json")
            else:
                logging.info(
                    f"There is already a request for this key. But there is no data in it. Key: {key}."
                )
                s = time.time()
                counter = 0
                # poll until the data is available on Redis
                while not any(res):
                    time.sleep(0.1)  # back off between polls instead of hammering Redis
                    result = redis_client.get(key)
                    if result is None:
                        # the producing request failed or the placeholder expired
                        abort(503, description="Cached data is not available, please retry")
                    res = cache_res_pickled(result, False)
                    counter += 1
                logging.info(
                    f"The data was available after {time.time() - s} seconds. Had to loop {counter} times.🤦"
                )
                return Response(generate(), content_type="application/json")
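The lagging-generator trick from `generate()` can be isolated into a reusable function. A minimal sketch with a hypothetical name, `stream_json_array`: it looks one item ahead so the final element is emitted with the closing brackets instead of a trailing comma, and it falls back to `str()` for values such as dates that `json` cannot serialize (an assumption standing in for `json_serial`). In the resource above it could take the place of the inner `generate()`.

```python
import json
from typing import Any, Iterable, Iterator


def stream_json_array(items: Iterable[Any], default=str) -> Iterator[str]:
    """Yield chunks forming {"data": [...]} without building the whole list in memory."""
    iterator = iter(items)
    try:
        prev = next(iterator)  # look one item ahead so we know which one is last
    except StopIteration:
        yield '{"data": []}'
        return
    yield '{"data": ['
    for item in iterator:
        yield json.dumps(prev, default=default) + ", "
        prev = item
    # the last element gets the closing brackets instead of a trailing comma
    yield json.dumps(prev, default=default) + "]}"
```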
Inspired by durga's answer, I wrote a very basic decorator which uses Redis directly instead of any library.
import base64
import functools
import json

import jsons

from src.consts import config
from src.utils.external_services import redis_connector


class cached(object):
    def __init__(self, req, timeout=None):
        self.timeout = timeout or config.CACHE_DEFAULT_TIMEOUT
        self.request = req
        self.cache = redis_connector.get_redis_instance()

    def __call__(self, f):
        @functools.wraps(f)
        def decorator(*args, **kwargs):
            redis_healthy = True
            if self.cache is not None:
                try:
                    self.cache.ping()
                except Exception:
                    redis_healthy = False
            else:
                redis_healthy = False
            if self.request is not None and self.request.values is not None and self.request.path is not None and redis_healthy:
                cache_key = "{}-{}".format(self.request.path, json.dumps(jsons.dump(self.request.values), sort_keys=True))
                # utf-8 rather than ascii: request.path may contain non-ASCII characters
                cache_key_base_64 = base64.b64encode(cache_key.encode("utf-8")).decode("ascii")
                response = self.cache.get(cache_key_base_64)
                if response is None:
                    response = f(*args, **kwargs)
                    self.cache.setex(cache_key_base_64, self.timeout, jsons.dumps(response))
                else:
                    response = json.loads(response)
            else:
                response = f(*args, **kwargs)
            return response
        return decorator
Now use this decorator on your API functions:
from flask import g, request
from flask_restful import Resource
from webargs.flaskparser import use_args

# gen_args: the author's own module of webargs argument definitions (not shown)


class GetProducts(Resource):
    @use_args(gen_args.argsGetProducts)
    @cached(request)
    def get(self, args):
        return "hello from products"
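The key derivation inside the decorator is what makes `?a=1&b=2` and `?b=2&a=1` hit the same entry. A minimal sketch of just that step, assuming the request values are already available as a plain dict; `make_cache_key` is a hypothetical name, and the URL-safe base64 alphabet is a swapped-in choice that avoids `/` and `+` in keys. Passing `request` at decoration time, as above, works because Flask's `request` is a context-local proxy resolved on each call.

```python
import base64
import json
from typing import Any, Mapping


def make_cache_key(path: str, values: Mapping[str, Any]) -> str:
    """Key = base64(path + '-' + JSON of the request values with sorted keys)."""
    raw = "{}-{}".format(path, json.dumps(values, sort_keys=True))
    # utf-8 copes with non-ASCII paths; the result is plain ASCII either way
    return base64.urlsafe_b64encode(raw.encode("utf-8")).decode("ascii")
```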
So I've been trying to implement an LRU cache for my project,
using Python's functools lru_cache.
As a reference I used this.
The following is the code used from the reference.
from datetime import datetime, timedelta
from functools import lru_cache, wraps


def timed_lru_cache(maxsize, seconds):
    def wrapper_cache(func):
        func = lru_cache(maxsize=maxsize)(func)
        func.lifetime = timedelta(seconds=seconds)
        func.expiration = datetime.utcnow() + func.lifetime

        @wraps(func)
        def wrapped_func(*args, **kwargs):
            if datetime.utcnow() >= func.expiration:
                func.cache_clear()
                func.expiration = datetime.utcnow() + func.lifetime
            return func(*args, **kwargs)
        return wrapped_func
    return wrapper_cache


@timed_lru_cache(maxsize=config.cache_size, seconds=config.ttl)
def load_into_cache(id):
    return object
In the wrapped_func part, func.cache_clear() clears the entire cache along with all the items.
I need help removing only the elements that are past their expiry time after insertion.
Is there any workaround?
I don't think it's so easy to adapt the existing lru_cache, and I don't think that linked method is very clear.
Instead I implemented a timed lru cache from scratch. See the docstring at the top for usage.
It stores a key based on the args and kwargs of the inputs, and manages two structures:
A mapping of key => (expiry, result)
A list of recently used, where the first item is the least recently used
Every time you try to get an item, the key is looked up in the "recently used" list. If it isn't there, it gets added to the list and the mapping. If it is there, we check if the expiry is in the past. If it is, we recalculate the result, and update. Otherwise we can just return whatever is in the mapping.
from datetime import datetime, timedelta
from functools import wraps
from typing import Any, Dict, List, Optional, Tuple


class TimedLRUCache:
    """ Cache that caches results based on an expiry time, and on least recently used.

    Items are eliminated first if they expire, and then if too many "recent" items are being
    stored.

    There are two methods of using this cache, either the `get` method, or calling this as a
    decorator. The `get` method accepts any arbitrary function, but only the parameters are
    considered in the key, so it is advisable not to mix functions.

    >>> cache = TimedLRUCache(5)
    >>> def foo(i):
    ...     return i + 1
    >>> cache.get(foo, (1,), {})  # runs foo
    2
    >>> cache.get(foo, (1,), {})  # returns the previously calculated result
    2

    As a decorator is more familiar:

    >>> @TimedLRUCache(5)
    ... def foo(i):
    ...     return i + 1
    >>> foo(1)  # runs foo
    2
    >>> foo(1)  # returns the previously calculated result
    2

    Either method can allow for fine-grained control of the cache:

    >>> five_second_cache = TimedLRUCache(5)
    >>> @five_second_cache
    ... def foo(i):
    ...     return i + 1
    >>> five_second_cache.clear_cache()  # resets the cache (clear every item)
    >>> five_second_cache.prune()  # clear invalid items
    """
    _items: Dict[int, Tuple[Optional[datetime], Any]]
    _recently_added: List[int]
    delta: Optional[timedelta]
    max_size: Optional[int]

    def __init__(self, seconds: Optional[int] = None, max_size: Optional[int] = None):
        self.delta = timedelta(seconds=seconds) if seconds else None
        self.max_size = max_size
        self._items = {}
        self._recently_added = []

    def __call__(self, func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            return self.get(func, args, kwargs)
        return wrapper

    @staticmethod
    def _get_key(args, kwargs) -> int:
        """ Get the thing we're going to use to lookup items in the cache. """
        key = (args, tuple(sorted(kwargs.items())))
        return hash(key)

    def _touch(self, key: int) -> None:
        """ Mark a key as the most recently used (the first item is the least recently used). """
        if key in self._recently_added:
            self._recently_added.remove(key)
        self._recently_added.append(key)

    def _update(self, key: int, item: Any) -> None:
        """ Make sure an item is up to date. """
        self._touch(key)
        # without `seconds`, items never expire (expiry is None)
        expiry = datetime.now() + self.delta if self.delta else None
        self._items[key] = (expiry, item)
        # when this function is called, something has changed, so we can also sort out the cache
        self.prune()

    def prune(self):
        """ Clear out everything that no longer belongs in the cache

        First delete everything that has expired. Then delete everything that isn't recent (only
        if there is a `max_size`).
        """
        current_time = datetime.now()
        # first get rid of things which have expired (iterate over a copy, since we delete)
        for key, (expiry, item) in list(self._items.items()):
            if expiry is not None and expiry < current_time:
                del self._items[key]
                self._recently_added.remove(key)
        # then make sure there aren't too many recent items, dropping their results too
        if self.max_size and len(self._recently_added) > self.max_size:
            for key in self._recently_added[:-self.max_size]:
                del self._items[key]
            self._recently_added[:-self.max_size] = []

    def clear_cache(self):
        """ Clear everything from the cache """
        self._items = {}
        self._recently_added = []

    def get(self, func, args, kwargs):
        """ Given a function and its arguments, get the result using the cache

        Get the key from the arguments of the function. If the key is in the cache, and the
        expiry time of that key hasn't passed, return the result from the cache.

        If the key *has* expired, or there are too many "recent" items, recalculate the result,
        add it to the cache, and then return the result.
        """
        key = self._get_key(args, kwargs)
        current_time = datetime.now()
        if key in self._recently_added:
            # there is something in the cache
            expiry, item = self._items[key]
            if expiry is not None and expiry < current_time:
                # the item has expired, so we need to get the new value
                new_item = func(*args, **kwargs)
                self._update(key, new_item)
                return new_item
            else:
                # we can use the existing value; a hit also counts as a recent use
                self._touch(key)
                return item
        else:
            # never seen this before, so add it
            new_item = func(*args, **kwargs)
            self._update(key, new_item)
            return new_item
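For comparison, the same two structures (key to (expiry, result), plus recency order) can be kept in a single `collections.OrderedDict`, whose `move_to_end` and `popitem(last=False)` provide the LRU ordering directly. This is a swapped-in technique, not the code above: a minimal sketch with hypothetical names and an injectable clock for testing. Only the expired entry is recomputed, which is what the question asks for; the eviction pass rescans every entry on insert, which is fine for small caches but O(n).

```python
import time
from collections import OrderedDict
from typing import Any, Callable, Hashable, Tuple


class ExpiringLRU:
    """LRU cache where each entry also carries its own expiry time."""

    def __init__(self, max_size: int, ttl: float,
                 clock: Callable[[], float] = time.monotonic):
        self.max_size = max_size
        self.ttl = ttl
        self.clock = clock
        self._data: "OrderedDict[Hashable, Tuple[float, Any]]" = OrderedDict()

    def get_or_compute(self, key: Hashable, compute: Callable[[], Any]) -> Any:
        now = self.clock()
        entry = self._data.get(key)
        if entry is not None and entry[0] > now:
            self._data.move_to_end(key)  # mark as most recently used
            return entry[1]
        value = compute()  # missing or expired: recompute only this entry
        self._data[key] = (now + self.ttl, value)
        self._data.move_to_end(key)
        self._evict(now)
        return value

    def _evict(self, now: float) -> None:
        # drop expired entries first, then the least recently used ones
        for key in [k for k, (exp, _) in self._data.items() if exp <= now]:
            del self._data[key]
        while len(self._data) > self.max_size:
            self._data.popitem(last=False)

    def __len__(self) -> int:
        return len(self._data)
```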
I have this Python script to control a pfSense router via FauxAPI. The problem is that when I call a function, it gives an error. I think I'm calling the function wrong. Does anyone know how to call them?
Here is a link to the API I'm using: https://github.com/ndejong/pfsense_fauxapi
I have tried calling config_get(self, section=none), but that does not seem to work.
import os
import json
import base64
import urllib.parse
import requests
import datetime
import hashlib

# placeholder: this standalone copy never defines __version__, but __init__ reads it
__version__ = 'unknown'


class PfsenseFauxapiException(Exception):
    pass


class PfsenseFauxapi:
    host = '172.16.1.1'
    proto = None
    debug = None
    version = None
    apikey = 'key'
    apisecret = 'secret'
    use_verified_https = None

    def __init__(self, host, apikey, apisecret, use_verified_https=False, debug=False):
        self.proto = 'https'
        self.base_url = 'fauxapi/v1'
        self.version = __version__
        self.host = host
        self.apikey = apikey
        self.apisecret = apisecret
        self.use_verified_https = use_verified_https
        self.debug = debug
        if self.use_verified_https is False:
            requests.packages.urllib3.disable_warnings(requests.packages.urllib3.exceptions.InsecureRequestWarning)

    def config_get(self, section=None):
        config = self._api_request('GET', 'config_get')
        if section is None:
            return config['data']['config']
        elif section in config['data']['config']:
            return config['data']['config'][section]
        raise PfsenseFauxapiException('Unable to complete config_get request, section is unknown', section)

    def config_set(self, config, section=None):
        if section is None:
            config_new = config
        else:
            config_new = self.config_get(section=None)
            config_new[section] = config
        return self._api_request('POST', 'config_set', data=config_new)

    def config_patch(self, config):
        return self._api_request('POST', 'config_patch', data=config)

    def config_reload(self):
        return self._api_request('GET', 'config_reload')

    def config_backup(self):
        return self._api_request('GET', 'config_backup')

    def config_backup_list(self):
        return self._api_request('GET', 'config_backup_list')

    def config_restore(self, config_file):
        return self._api_request('GET', 'config_restore', params={'config_file': config_file})

    def send_event(self, command):
        return self._api_request('POST', 'send_event', data=[command])

    def system_reboot(self):
        return self._api_request('GET', 'system_reboot')

    def system_stats(self):
        return self._api_request('GET', 'system_stats')

    def interface_stats(self, interface):
        return self._api_request('GET', 'interface_stats', params={'interface': interface})

    def gateway_status(self):
        return self._api_request('GET', 'gateway_status')

    def rule_get(self, rule_number=None):
        return self._api_request('GET', 'rule_get', params={'rule_number': rule_number})

    def alias_update_urltables(self, table=None):
        if table is not None:
            return self._api_request('GET', 'alias_update_urltables', params={'table': table})
        return self._api_request('GET', 'alias_update_urltables')

    def function_call(self, data):
        return self._api_request('POST', 'function_call', data=data)

    def system_info(self):
        return self._api_request('GET', 'system_info')

    def _api_request(self, method, action, params=None, data=None):
        if params is None:
            params = {}
        if self.debug:
            params['__debug'] = 'true'
        url = '{proto}://{host}/{base_url}/?action={action}&{params}'.format(
            proto=self.proto, host=self.host, base_url=self.base_url, action=action, params=urllib.parse.urlencode(params))
        if method.upper() == 'GET':
            res = requests.get(
                url,
                headers={'fauxapi-auth': self._generate_auth()},
                verify=self.use_verified_https
            )
        elif method.upper() == 'POST':
            res = requests.post(
                url,
                headers={'fauxapi-auth': self._generate_auth()},
                verify=self.use_verified_https,
                data=json.dumps(data)
            )
        else:
            raise PfsenseFauxapiException('Request method not supported!', method)
        if res.status_code == 404:
            raise PfsenseFauxapiException('Unable to find FauxAPI on target host, is it installed?')
        elif res.status_code != 200:
            raise PfsenseFauxapiException('Unable to complete {}() request'.format(action), json.loads(res.text))
        return self._json_parse(res.text)

    def _generate_auth(self):
        # auth = apikey:timestamp:nonce:HASH(apisecret + timestamp + nonce), hashed with no separators
        nonce = base64.b64encode(os.urandom(40)).decode('utf-8').replace('=', '').replace('/', '').replace('+', '')[0:8]
        timestamp = datetime.datetime.utcnow().strftime('%Y%m%dZ%H%M%S')
        hash = hashlib.sha256('{}{}{}'.format(self.apisecret, timestamp, nonce).encode('utf-8')).hexdigest()
        return '{}:{}:{}:{}'.format(self.apikey, timestamp, nonce, hash)

    def _json_parse(self, data):
        try:
            return json.loads(data)
        except json.JSONDecodeError:
            pass
        raise PfsenseFauxapiException('Unable to parse response data!', data)
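The `_generate_auth()` method is the part of this protocol most worth checking when requests are rejected. A standalone sketch of the same computation, with optional `timestamp` and `nonce` parameters (hypothetical, added only so the output can be reproduced); it mirrors the code above rather than the FauxAPI documentation, and it swaps the deprecated `utcnow()` for a timezone-aware UTC time with the same format.

```python
import base64
import datetime
import hashlib
import os
from typing import Optional


def fauxapi_auth(apikey: str, apisecret: str,
                 timestamp: Optional[str] = None, nonce: Optional[str] = None) -> str:
    """Build the value of the fauxapi-auth header, mirroring _generate_auth() above."""
    if nonce is None:
        raw = base64.b64encode(os.urandom(40)).decode("utf-8")
        nonce = raw.replace("=", "").replace("/", "").replace("+", "")[0:8]
    if timestamp is None:
        now = datetime.datetime.now(datetime.timezone.utc)
        timestamp = now.strftime("%Y%m%dZ%H%M%S")
    # the secret, timestamp and nonce are concatenated with no separator before hashing
    digest = hashlib.sha256((apisecret + timestamp + nonce).encode("utf-8")).hexdigest()
    return "{}:{}:{}:{}".format(apikey, timestamp, nonce, digest)
```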
Without having tested the above script myself, I can conclude that yes, you are calling the function wrong. The above script is a class, which must be instantiated before any of its functions can be used.
For example you could first create an object with:
pfsense = PfsenseFauxapi(host='<host>', apikey='<API key>', apisecret='<API secret>')
replacing <host>, <API key> and <API secret> with the required values
Then call the function with:
pfsense.config_get() # self is not passed
where config_get can be replaced with any function
Also note:
As soon as you call pfsense = PfsenseFauxapi(...), all the code in the __init__ function is also run, as it is the constructor (which is used to initialize all the attributes of the class).
When a function has a parameter of the form parameter=something, that something is the default value used when nothing is passed for that parameter. That is why use_verified_https, debug and section do not need to be passed (unless you want to change them, of course).
Here is some more information on classes if you need it.
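The difference between calling a method on the class and on an instance can be seen without a pfSense box. A minimal sketch using a hypothetical `Client` class with the same constructor shape as `PfsenseFauxapi` but no network requests; 192.0.2.1 is a documentation-only address.

```python
from typing import Optional


class Client:
    """Stand-in for PfsenseFauxapi: same constructor shape, no network access."""

    def __init__(self, host: str, apikey: str, apisecret: str,
                 use_verified_https: bool = False, debug: bool = False):
        # __init__ runs once, when the object is created
        self.host = host
        self.apikey = apikey
        self.apisecret = apisecret
        self.use_verified_https = use_verified_https
        self.debug = debug

    def config_get(self, section: Optional[str] = None) -> str:
        # `self` is filled in automatically when called on an instance
        return "all sections" if section is None else "section " + section


client = Client(host="192.0.2.1", apikey="KEY", apisecret="SECRET")
print(client.config_get())          # all sections
print(client.config_get("system"))  # section system

try:
    Client.config_get()  # called on the class: there is no instance to bind to `self`
except TypeError as exc:
    print("TypeError:", exc)
```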
You need to create an object of the class in order to call the functions of the class. For example:
x = PfsenseFauxapi(host, apikey, apisecret) (the __init__ method is called while constructing the object)
and then call x.any_function(). You should probably give the variable a more descriptive name than x.
I am having a weird bug that seems related to Django's caching.
I have a 3-step registration process:
insert personal data
insert company data
summary view and submit all data for registration
If person A walks through the process to the summary part but does not submit the form, and person B does the same, person B gets the data of person A in the summary view.
The data gets stored in a Storage object which carries the data through each step. Every new registration instantiates a new Storage object (at least it should).
While debugging I found that Django does not call any method in the corresponding views when the cache is already warmed up (by another running registration), and I guess that's why there is no new Storage instance. Hence the cross-pollution of data.
Now, I'm perfectly aware that I can decorate the method with @never_cache() (which it already was), but that doesn't do the trick.
I've also found that the @never_cache decorator does not work properly prior to Django 1.9(?), as it misses some headers.
One solution I found was to set these headers myself with @cache_control(max_age=0, no_cache=True, no_store=True, must_revalidate=True). But that also doesn't work.
So how can I properly disable caching for these methods?
Here is some relevant code:
# views.py
def _request_storage(request, **kwargs):
    try:
        return getattr(request, '_registration_storage')
    except AttributeError:
        from .storage import Storage
        storage = Storage(request, 'registration')
        setattr(request, '_registration_storage', storage)
        return storage

...

# Route that gets called by clicking the "register" button
@secure_required
@cache_control(max_age=0, no_cache=True, no_store=True, must_revalidate=True)
# also does not work with @never_cache()
def registration_start(request, **kwargs):
    storage = _request_storage(request)
    storage.clear()
    storage.store_data('is_authenticated', request.user.is_authenticated())
    storage.store_data('user_pk', request.user.pk if request.user.is_authenticated() else None)
    return HttpResponseRedirect(reverse_i18n('registration_package', kwargs={'pk': 6}))


def registration_package(request, pk=None, **kwargs):
    """stores the package_pk in storage"""
    ...


def registration_personal(request, pk=None, **kwargs):
    from .forms import PersonalForm
    storage = _request_storage(request)
    if request.method == 'POST':
        form = PersonalForm(request.POST)
        if form.is_valid():
            storage.store_form('personal_form_data', form)
            return HttpResponseRedirect(reverse_i18n('registration_company'))
    else:
        if request.GET.get('revalidate', False):
            form = storage.retrieve_form('personal_form_data', PersonalForm)
        else:
            form = storage.retrieve_initial_form('personal_form_data', PersonalForm)
    return render_to_response('registration/personal.html', {
        'form': form,
        'step': 'personal',
        'previous': _previous_steps(request),
    }, context_instance=RequestContext(request))

# the other steps are pretty much the same
# storage.py
class Storage(object):
    def __init__(self, request, prefix):
        self.request = request
        self.prefix = prefix

    def _debug(self):
        from pprint import pprint
        self._init_storage()
        pprint(self.request.session[self.prefix])

    def exists(self):
        return self.prefix in self.request.session

    def has(self, key):
        if self.prefix in self.request.session:
            return key in self.request.session[self.prefix]
        return False

    def _init_storage(self):
        if self.prefix not in self.request.session:
            self.request.session[self.prefix] = {}
            self.request.session.modified = True

    def clear(self):
        self.request.session[self.prefix] = {}
        self.request.session.modified = True

    def store_data(self, key, data):
        self._init_storage()
        self.request.session[self.prefix][key] = data
        self.request.session.modified = True

    def update_data(self, key, data):
        self._init_storage()
        if key in self.request.session[self.prefix]:
            self.request.session[self.prefix][key].update(data)
            self.request.session.modified = True
        else:
            self.store_data(key, data)

    def retrieve_data(self, key, fallback=None):
        self._init_storage()
        return self.request.session[self.prefix].get(key, fallback)

    def store_form(self, key, form):
        self.store_data(key, form.data)

    def retrieve_form_data(self, key):
        return self.retrieve_data(key)

    def retrieve_form(self, key, form_class):
        data = self.retrieve_form_data(key)
        form = form_class(data=data)
        return form

    def retrieve_initial_form(self, key, form_class):
        data = self.retrieve_form_data(key)
        form = form_class(initial=self.convert_form_data_to_initial(data))
        return form

    def convert_form_data_to_initial(self, data):
        result = {}
        if data is None:
            return result
        for key in data:
            try:
                values = data.getlist(key)
                if len(values) > 1:
                    result[key] = values
                else:
                    result[key] = data.get(key)
            except AttributeError:
                result[key] = data.get(key)
        return result

    def retrieve_process_form(self, key, form_class, initial=None):
        # use the stored request; there is no module-level `request` here
        if self.request.method == 'POST':
            return form_class(data=self.request.POST)
        else:
            data = self.retrieve_form_data(key)
            initial = self.convert_form_data_to_initial(data) or initial
            return form_class(initial=initial)
I've tested this across browsers, computers and networks. When I clear the cache manually, it works again.
(This is easy to see with simple debugging print() statements, which do not get called with a warm cache.)
It can't be too hard to selectively disable caching, right?
Additional question:
Could it be that the @never_cache() decorator just prevents browser caching and has nothing to do with the Redis cache?
In my view, I accept JSON keys and values in request.body. I plan to check for the existence of the required JSON keys (to catch a typo, maybe) in another function:
def checkJsonKey(form, *args):
    for key in enumerate(args):
        if key not in form:
            return HttpResponse(status=400)  # <--
Instead of checking the returned value, can this function directly return a response and terminate the view function?
In my view function:
form = json.loads(request.body)
checkJsonKey(form, "user_preference", "model_id", "filename")
Here's how I would handle this: raise an exception and then catch it, rather than returning the response. This allows you to signal a number of error conditions from your function and handle them in a try/except block.
def checkJsonKey(form, *args):
    # assuming that form should be a dict
    if not isinstance(form, dict):
        raise ValueError("Not dict")
    # assuming you actually want to deal with a list of the args, and not tuple pairs from enumerate
    for key in args:
        if key not in form:
            raise ValueError("Key not found: %s" % key)
    return True
This would be called as:
form = json.loads(request.body)
try:
    checkJsonKey(form, "user_preference", "model_id", "filename")
except ValueError as e:
    return HttpResponseBadRequest("%s" % e)
# rest of the code
If you wanted to get really fancy, you could define your own error class and catch that specifically.
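A custom error class, as suggested above, lets the view catch exactly this failure and report every missing key at once. A minimal sketch with hypothetical names (`MissingKeysError`, `require_keys`); subclassing `ValueError` keeps an existing `except ValueError` handler working.

```python
import json
from typing import Any, Iterable


class MissingKeysError(ValueError):
    """Raised when a JSON body lacks required keys; reports all of them at once."""

    def __init__(self, missing: Iterable[str]):
        self.missing = sorted(missing)
        super().__init__("Missing keys: " + ", ".join(self.missing))


def require_keys(form: Any, *keys: str) -> dict:
    """Return form unchanged if it is a dict containing every key in keys."""
    if not isinstance(form, dict):
        raise ValueError("JSON body must be an object")
    missing = [key for key in keys if key not in form]
    if missing:
        raise MissingKeysError(missing)
    return form


# e.g. a body with one of the three required keys missing
body = '{"model_id": 1, "filename": "a.csv"}'
try:
    require_keys(json.loads(body), "user_preference", "model_id", "filename")
except MissingKeysError as exc:
    print(exc)  # Missing keys: user_preference
```

In the view, `except MissingKeysError as e: return HttpResponseBadRequest(str(e))` would then replace the generic handler.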
Check whether the value returned from checkJsonKey is None; if it is not, it must have returned your response, and you can return that from the view function.
def my_view(request, *args, **kwargs):
    form = json.loads(request.body)
    response = checkJsonKey(form, "user_preference", "model_id", "filename")
    if response is not None:
        return response
It won't make much difference, but if you want another method: json.loads() returns a dictionary from the JSON string passed to it, so the existence of keys can easily be checked with the in operator (key in form) or with form.get(key); form[key] itself raises a KeyError when the key is missing. You can find more in the documentation of the json module.
Try this:
def checkJsonKey(form, *args):
    for key in args:
        # form.get() returns None for a missing key instead of raising KeyError
        if form.get(key) is None:
            return False
    return True

def my_view(request, *args, **kwargs):
    form = json.loads(request.body)
    response = checkJsonKey(form, "user_preference", "model_id", "filename")
    if response:
        return HttpResponse(status=200)
    else:
        return HttpResponse(status=400)
I have numerous tornado.web.RequestHandler classes that test for authorized access using id and access key secure cookies. I access mongodb asynchronously with inline callbacks using gen.Task. I am having trouble figuring out a way to factor out the repetitive code because of its asynchronicity. How can I do this?
class MyHandler(RequestHandler):
    @tornado.web.asynchronous
    @gen.engine
    def get(self):
        id = self.get_secure_cookie('id', None)
        accesskey = self.get_secure_cookie('accesskey', None)
        if not id or not accesskey:
            self.redirect('/a_public_area')
            return
        try:
            # convert to bson id format to access mongodb
            bson.objectid.ObjectId(id)
        except:
            # if not valid object id
            self.redirect('/a_public_area')
            return
        found_id, error = yield gen.Task(asyncmong_client_inst.collection.find_one,
                                         {'_id': id, 'accesskey': accesskey}, fields={'_id': 1})
        if error['error']:
            raise HTTPError(500)
            return
        if not found_id[0]:
            self.redirect('/a_public_area')
            return
        # real business code follows
I would like to factor the above into a function that yields perhaps an HTTP status code.
Tornado has the @tornado.web.authenticated decorator. Let's use it.
class BaseHandler(RequestHandler):
    def get_login_url(self):
        return u"/a_public_area"

    @gen.engine  # Not sure about this step
    def get_current_user(self):
        id = self.get_secure_cookie('id', None)
        accesskey = self.get_secure_cookie('accesskey', None)
        if not id or not accesskey:
            return False
        # Are you sure you need this?
        try:
            # convert to bson id format to access mongodb
            bson.objectid.ObjectId(id)
        except:
            # if not valid object id
            return False
        # I believe you don't need asynchronous mongo for the auth query, so if it's not working, replace it with a sync call
        found_id, error = yield gen.Task(asyncmong_client_inst.collection.find_one,
                                         {'_id': id, 'accesskey': accesskey}, fields={'_id': 1})
        if error['error']:
            raise HTTPError(500)
        if not found_id[0]:
            return False
        return found_id


class MyHandler(BaseHandler):
    @tornado.web.asynchronous
    @tornado.web.authenticated
    @gen.engine
    def get(self):
        # real business code follows
        pass
Using gen everywhere is not good practice. It can turn your code into one big plate of spaghetti. Think about it.
Perhaps a decorator (not tested or anything, just some ideas):
def sanitize(fn):
    def _sanitize(self, *args, **kwargs):
        id = self.get_secure_cookie('id', None)
        accesskey = self.get_secure_cookie('accesskey', None)
        if not id or not accesskey:
            self.redirect('/a_public_area')
            return
        try:
            # convert to bson id format to access mongodb
            bson.objectid.ObjectId(id)
        except:
            # if not valid object id
            self.redirect('/a_public_area')
            return
        return fn(self, *args, **kwargs)
    return _sanitize
I don't know if you can make check_errors work with the business logic, but maybe:
def check_errors(fn):
    def _check_errors(self, *args, **kwargs):
        found_id, error = fn(self, *args, **kwargs)
        if error['error']:
            raise HTTPError(500)
            return
        if not found_id[0]:
            self.redirect('/a_public_area')
            return
    return _check_errors
then:
class MyHandler(RequestHandler):
    @tornado.web.asynchronous
    @gen.engine
    @sanitize
    @check_errors  # ..O.o decorators
    def get(self):
        found_id, error = yield gen.Task(asyncmong_client_inst.collection.find_one,
                                         {'_id': id, 'accesskey': accesskey}, fields={'_id': 1})
        return found_id, error
I'd like to address this general problem with gen.Task, which is that factoring out code is either impossible or extremely clumsy.
You can only do "yield gen.Task(...)" within the get() or post() method. If you want to have get() call another function foo(), and do the work in foo(), well: You can't, unless you want to write everything as a generator and chain them together in some unwieldy way. As your project gets bigger, this is going to be a huge problem.
This is a much better alternative: https://github.com/mopub/greenlet-tornado
We used this to convert a large synchronous codebase to Tornado, with almost no changes.
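For comparison with the gen.Task limitation described above: in Tornado versions that accept native coroutines (async def handlers), the access checks can be factored into a helper that the handler simply awaits. A minimal sketch using only asyncio, where `find_one` is a hypothetical stand-in for an async MongoDB lookup, the helper returns a redirect target or None, and the ObjectId validation is omitted because it needs bson.

```python
import asyncio
from typing import Awaitable, Callable, Optional

FindOne = Callable[[dict], Awaitable[Optional[dict]]]


async def check_access(user_id: Optional[str], accesskey: Optional[str],
                       find_one: FindOne) -> Optional[str]:
    """Return a URL to redirect to, or None if the caller is authorized."""
    if not user_id or not accesskey:
        return "/a_public_area"
    found = await find_one({"_id": user_id, "accesskey": accesskey})
    return None if found else "/a_public_area"


async def fake_find_one(query: dict) -> Optional[dict]:
    # stands in for an async MongoDB client's find_one
    await asyncio.sleep(0)
    if query == {"_id": "u1", "accesskey": "k1"}:
        return {"_id": "u1"}
    return None


async def handler_get(user_id: Optional[str], accesskey: Optional[str]) -> str:
    # what a handler's `async def get(self)` could do: one await, one early return
    redirect = await check_access(user_id, accesskey, fake_find_one)
    if redirect:
        return "redirect to " + redirect
    return "business logic runs"
```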