From a4a81e00ae397d6823ba7c9b2b65a056350cd0d3 Mon Sep 17 00:00:00 2001 From: Adrian Moennich Date: Wed, 20 May 2015 14:37:49 +0200 Subject: [PATCH] Pass SQLAlchemy objects cleanly to Celery tasks --- indico/core/celery/core.py | 55 +++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 54 insertions(+), 1 deletion(-) diff --git a/indico/core/celery/core.py b/indico/core/celery/core.py index 650cf89e5..1924d33ad 100644 --- a/indico/core/celery/core.py +++ b/indico/core/celery/core.py @@ -20,10 +20,13 @@ import os from celery import Celery from celery.beat import PersistentScheduler +from celery.signals import before_task_publish +from sqlalchemy import inspect from indico.core.celery import CELERY_IMPORTS from indico.core.config import Config -from indico.core.db import DBMgr +from indico.core.db import DBMgr, db +from indico.util.string import return_ascii class IndicoCelery(Celery): @@ -99,6 +102,8 @@ class IndicoCelery(Celery): def __call__(s, *args, **kwargs): with self.flask_app.app_context(): with DBMgr.getInstance().global_connection(): + args = _CelerySAWrapper.unwrap_args(args) + kwargs = _CelerySAWrapper.unwrap_kwargs(kwargs) return super(IndicoTask, s).__call__(*args, **kwargs) self.Task = IndicoTask @@ -120,3 +125,51 @@ class IndicoPersistentScheduler(PersistentScheduler): else: self.app.conf['CELERYBEAT_SCHEDULE'][task_name]['schedule'] = entry super(IndicoPersistentScheduler, self).setup_schedule() + + +class _CelerySAWrapper(object): + """Wrapper to safely pass SQLAlchemy objects to tasks. + + This is achieved by passing only the model name and its PK values + through the Celery serializer and then fetching the actual objects + again when executing the task. + """ + __slots__ = ('identity_key',) + + def __init__(self, obj): + self.identity_key = inspect(obj).identity_key + + @property + def object(self): + obj = self.identity_key[0].get(self.identity_key[1]) + if obj is None: + raise ValueError('Object not in DB: {}'.format(self)) + return obj + + @return_ascii + def __repr__(self): + model, args = self.identity_key + return '<{}: {}>'.format(model.__name__, ','.join(map(repr, args))) + + @classmethod + def wrap_args(cls, args): + return tuple(cls(x) if isinstance(x, db.Model) else x for x in args) + + @classmethod + def wrap_kwargs(cls, kwargs): + return {k: cls(v) if isinstance(v, db.Model) else v for k, v in kwargs.iteritems()} + + @classmethod + def unwrap_args(cls, args): + return tuple(x.object if isinstance(x, cls) else x for x in args) + + @classmethod + def unwrap_kwargs(cls, kwargs): + return {k: v.object if isinstance(v, cls) else v for k, v in kwargs.iteritems()} + + +@before_task_publish.connect +def before_task_publish_signal(*args, **kwargs): + body = kwargs['body'] + body['args'] = _CelerySAWrapper.wrap_args(body['args']) + body['kwargs'] = _CelerySAWrapper.wrap_kwargs(body['kwargs']) -- 2.11.4.GIT