Graphene-Python: automatic schema generation from Django model - python

I am trying to generate a Graphene schema from a Django model. I am trying to do this by iterating through the apps then the models and then adding the appropriate attributes to the generated schema.
This is the code:
registry = {}
def register(target_class):
registry[target_class.__name__] = target_class
def c2u(name):
s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name)
return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower()
def s2p(name):
s1 = re.sub("y$", "ie", name)
return "{}s".format(s1)
class AutoSchemaMeta(type):
def __new__(meta, clsname, superclasses, attributedict):
new_class = type(clsname, superclasses, attributedict)
for app_name in new_class.app_models.split(","):
app_models = apps.get_app_config(app_name.strip()).get_models()
for model in app_models:
model_name = model._meta.model_name
_model_name = c2u(model_name)
if hasattr(new_class,_model_name):
continue
_node_class = type("{}Node".format(model_name.title()),
(DjangoObjectType,),
{"Meta":{"model": model, "interfaces": (Node,), "filter_fields": []}})
register(_node_class)
setattr(new_class, "all_{}".format(s2p(_model_name)), DjangoFilterConnectionField(_node_class))
setattr(new_class, _model_name, Node.Field(_node_class))
print(new_class.__dict__)
return new_class
class Query(metaclass=AutoSchemaMeta):
app_models = "app1,app2"
When I run my application I get an exception:
AssertionError: Found different types with the same name in the
schema: WorkflowNode, WorkflowNode.
Turns out there is a class already defined as WorkflowNode and I do not want to override it. So now I am stuck at finding out the classes that are already defined.
I am already excluding by attributes name with if hasattr(new_class,_model_name): continue but I would like to not rely on conventions and find out also all Nodeclasses that have been defined elsewhere and if they exist use them instead of the one I am creating automatically

I tried the proposed solution but it doesn't work for a lot of reasons, including metaclass conflicts with graphene.ObjectType so I created a solution that works pretty well:
you would provide the subclassed ObjectType a list of your ORM models (in my case SQLAlchemy) and it auto creates the schema. The only thing left to do would be to add special handling if you needed to add extra filtering options for any of the fields.
class SQLAlchemyAutoSchemaFactory(graphene.ObjectType):
#staticmethod
def set_fields_and_attrs(klazz, node_model, field_dict):
_name = camel_to_snake(node_model.__name__)
field_dict[f'all_{(s2p(_name))}'] = FilteredConnectionField(node_model)
field_dict[_name] = node_model.Field()
# log.info(f'interface:{node_model.__name__}')
setattr(klazz, _name, node_model.Field())
setattr(klazz, "all_{}".format(s2p(_name)), FilteredConnectionField(node_model))
#classmethod
def __init_subclass_with_meta__(
cls,
interfaces=(),
models=(),
excluded_models=(),
default_resolver=None,
_meta=None,
**options
):
if not _meta:
_meta = ObjectTypeOptions(cls)
fields = OrderedDict()
for interface in interfaces:
if issubclass(interface, SQLAlchemyInterface):
SQLAlchemyAutoSchemaFactory.set_fields_and_attrs(cls, interface, fields)
for model in excluded_models:
if model in models:
models = models[:models.index(model)] + models[models.index(model) + 1:]
possible_types = ()
for model in models:
model_name = model.__name__
_model_name = camel_to_snake(model.__name__)
if hasattr(cls, _model_name):
continue
if hasattr(cls, "all_{}".format(s2p(_model_name))):
continue
for iface in interfaces:
if issubclass(model, iface._meta.model):
model_interface = (iface,)
break
else:
model_interface = (CustomNode,)
_node_class = type(model_name,
(SQLAlchemyObjectType,),
{"Meta": {"model": model, "interfaces": model_interface, "only_fields": []}})
fields["all_{}".format(s2p(_model_name))] = FilteredConnectionField(_node_class)
setattr(cls, "all_{}".format(s2p(_model_name)), FilteredConnectionField(_node_class))
fields[_model_name] = CustomNode.Field(_node_class)
setattr(cls, _model_name, CustomNode.Field(_node_class))
possible_types += (_node_class,)
if _meta.fields:
_meta.fields.update(fields)
else:
_meta.fields = fields
_meta.schema_types = possible_types
super(SQLAlchemyAutoSchemaFactory, cls).__init_subclass_with_meta__(_meta=_meta, default_resolver=default_resolver, **options)
#classmethod
def resolve_with_filters(cls, info: ResolveInfo, model: Type[SQLAlchemyObjectType], **kwargs):
query = model.get_query(info)
for filter_name, filter_value in kwargs.items():
model_filter_column = getattr(model._meta.model, filter_name, None)
if not model_filter_column:
continue
if isinstance(filter_value, SQLAlchemyInputObjectType):
filter_model = filter_value.sqla_model
q = FilteredConnectionField.get_query(filter_model, info, sort=None, **kwargs)
# noinspection PyArgumentList
query = query.filter(model_filter_column == q.filter_by(**filter_value))
else:
query = query.filter(model_filter_column == filter_value)
return query
and you create the Query like this:
class Query(SQLAlchemyAutoSchemaFactory):
class Meta:
interfaces = (Interface1, Interface2,)
models = (*entities_for_iface1, *entities_for_iface2, *other_entities,)
excluded_models = (base_model_for_iface1, base_model_for_iface2)
create an interface like this:
class Interface1(SQLAlchemyInterface):
class Meta:
name = 'Iface1Node'
model = Iface1Model
and SQLAlchemyInterface:
class SQLAlchemyInterface(Node):
#classmethod
def __init_subclass_with_meta__(
cls,
model=None,
registry=None,
only_fields=(),
exclude_fields=(),
connection_field_factory=default_connection_field_factory,
_meta=None,
**options
):
_meta = SQLAlchemyInterfaceOptions(cls)
_meta.name = f'{cls.__name__}Node'
autoexclude_columns = exclude_autogenerated_sqla_columns(model=model)
exclude_fields += autoexclude_columns
assert is_mapped_class(model), (
"You need to pass a valid SQLAlchemy Model in " '{}.Meta, received "{}".'
).format(cls.__name__, model)
if not registry:
registry = get_global_registry()
assert isinstance(registry, Registry), (
"The attribute registry in {} needs to be an instance of "
'Registry, received "{}".'
).format(cls.__name__, registry)
sqla_fields = yank_fields_from_attrs(
construct_fields(
model=model,
registry=registry,
only_fields=only_fields,
exclude_fields=exclude_fields,
connection_field_factory=connection_field_factory
),
_as=Field
)
if not _meta:
_meta = SQLAlchemyInterfaceOptions(cls)
_meta.model = model
_meta.registry = registry
connection = Connection.create_type(
"{}Connection".format(cls.__name__), node=cls)
assert issubclass(connection, Connection), (
"The connection must be a Connection. Received {}"
).format(connection.__name__)
_meta.connection = connection
if _meta.fields:
_meta.fields.update(sqla_fields)
else:
_meta.fields = sqla_fields
super(SQLAlchemyInterface, cls).__init_subclass_with_meta__(_meta=_meta, **options)
#classmethod
def Field(cls, *args, **kwargs): # noqa: N802
return NodeField(cls, *args, **kwargs)
#classmethod
def node_resolver(cls, only_type, root, info, id):
return cls.get_node_from_global_id(info, id, only_type=only_type)
#classmethod
def get_node_from_global_id(cls, info, global_id, only_type=None):
try:
node: DeclarativeMeta = one_or_none(session=info.context.get('session'), model=cls._meta.model, id=global_id)
return node
except Exception:
return None
#classmethod
def from_global_id(cls, global_id):
return global_id
#classmethod
def to_global_id(cls, type, id):
return id
#classmethod
def resolve_type(cls, instance, info):
if isinstance(instance, graphene.ObjectType):
return type(instance)
graphene_model = get_global_registry().get_type_for_model(type(instance))
if graphene_model:
return graphene_model
else:
raise ValueError(f'{instance} must be a SQLAlchemy model or graphene.ObjectType')

Related

dynamically creating classes with dynamically created methods()

I'm trying to create a dynamic class in python that also has dynamic properties; but I'm struggling with this.
Here's a simple non-working sample:
class baseA(object):
def __init__(self):
self._a_info = 1
def dump(self, pref=""):
print("%s %d" % (pref, self._a_info))
def b_init(self, parent_class):
parent_class.__init__(self)
def _dump(self, pref=""):
print("%s: %d" % self._b_info)
attrs = {"__init__": b_init,
"dump": lambda self, pref: _dump(self, pref=pref)}
for field in ["field1", "field2"]:
attrs["_%s" % field] = field
attrs[field] = lambda self: getattr(self, "_%s" % f)
tmpb = type("baseB",
(baseA, ),
attrs)
t = tmpb()
t.dump(pref="Field: ")
Obviously, the above doesn't work. For one thing print(t.field1) will print an unbounded method warning, since attrs[prop] is a function and not a value. (I was hoping to simulate
what #property does to methods).
What I'm trying to do is to create a class dynamically while setting properties
for it. That said, I now realize that "attrs[prop] = lambda self: getattr(self, "_%s" % prop)
is wrong as that makes attrs[prop] a function.
Is it even possible to use the type() function to create a dynamic class that has
the following property getter/setters?
So like converting the following:
class baseB(baseA):
def __init__(self):
self._field1 = "field1"
self._field2 = "field2"
self._field3 = "field3"
#property
def field1(self):
return self._field1
#field1.setter
def field1(self, in_val):
self._field1 = in_val
#property
def field2(self):
return self._field2
#field2.setter
def field2(self, in_val):
self._field2 = in_val
#property
def field3(self):
return self._field3
#field3.setter
def field3(self, in_val):
self._field3 = in_val
to
type("baseB",
(baseA, ),
someattributes_dictionary)
?
If it was a one off script, then sure I'd just do the long way; but if I need
to dynamically create different classes with different properties, the typical
'class ...' will be tedious and cumbersome.
Any clarifications appreciated,
Ed.
--
[1] - https://www.python-course.eu/python3_classes_and_type.php

Django Setting Custom Path for Storage ('str object is not callable')

I have a custom storage class I'm using and I am trying to set a different path to upload the file to but I keep getting "Str object is not callable".
Here's the code where I call my custom class:
fs = DGStorage(user_login=user_login, name=name, content=content, csv=True)
fs.path = os.path.join("/u/vnc/web", "docs", "upload", "csv")
When fs.save is called, I get the str object error (very annoying). I don't set the path in the storage class because it's set somewher else.
Custom Storage Model:
class DGStorage(FileSystemStorage):
def __init__(self, name=None, content=None, user_login=None, location=None, base_url=None,
file_permissions_mode=None, directory_permissions_mode=None, csv=False):
# FileSystemStorage.__init__(self, self.location)
self._location = location
self._base_url = base_url
self._file_permissions_mode = file_permissions_mode
self._directory_permissions_mode = directory_permissions_mode
self.name = name
self.content = content
self.user_login = str(user_login)
self.user_dir = self.user_login + datetime.now().strftime('_%Y%m%d')
self.csv = csv
super(DGStorage, self).__init__()
#property
def user_directory(self):
return safe_join(self.location, 'uploads', self.user_dir)
#property
def get_filename(self):
if not self.name:
self.name = self.content.name
name = self.name
name = name.replace(' ', '_')
return name
def path(self, name=None):
if not name:
name = self.content.name
else:
name = name.replace(' ', '_')
#if self.csv:
# return os.path.join("/u/vnc/web", "docs", "upload", "csv")
return safe_join(self.location, 'uploads', self.user_dir, name)
The self.csv doesn't seem to work either, it just halts when I set the path that way as well.
Used an if statement in the path. Wasn't passing in the file name.
Silly of me.

Python: Generic/Templated getters

I have a code that simply fetches a user/s from a database
class users:
def __init__(self):
self.engine = create_engine("mysql+pymysql://root:password#127.0.0.1/my_database")
self.connection = self.engine.connect()
self.meta = MetaData(bind=self.connection)
self.users = Table('users', self.meta, autoload = true)
def get_user_by_user_id(self, user_id):
stmt = self.users.select().where(self.users.c.user_id == user_id)
return self.connection.execute(stmt)
def get_user_by_username(self, username):
stmt = self.users.select().where(self.users.c.username == username)
return self.connection.execute(stmt)
def get_users_by_role_and_company(self, role, company)
stmt = self.users.select().where(self.users.c.role == role).where(self.users.c.company == company)
return self.connection.execute(stmt)
Now, what I want to do is to make the getters generic like so:
class users:
def __init__(self):
self.engine = create_engine("mysql+pymysql://root:password#127.0.0.1/my_database")
self.connection = self.engine.connect()
self.meta = MetaData(bind=self.connection)
self.users = Table('users', self.meta, autoload = true)
def get_user(self, **kwargs):
'''How do I make this generic function?'''
So, instead of calling something like this:
u = users()
u.get_user_by_user_id(1)
u.get_user_by_username('foo')
u.get_users_by_role_and_company('Admin', 'bar')
I would just call the generic function like so:
u = users()
u.get_user(user_id=1)
u.get_user(username='foo')
u.get_user(role='Admin', company='bar')
So far, this was what I could think of:
def get_user(**kwargs):
where_clause = ''
for key, value in kwargs.items():
where_clause += '{} == {} AND '.format(key, value)
where_clause = where_clause[:-5] # remove final AND
stmt = "SELECT * FROM {tablename} WHERE {where_clause};".format(tablename='users', where_clause=where_clause)
return self.connection.execute(stmt)
Is there any way that I could use the ORM style to create the statement?
Fully generalized, so any combination of legal field names is accepted as long as a field of that name exists in the table. The magic is in getattr, which allows us to look up the field we're interested in dynamically (and raises AttributeError if called with a non-existent field name):
def get_user(self, **kwargs):
# Basic query
stmt = self.users.select()
# Query objects can be refined piecemeal, so we just loop and
# add new clauses, assigning back to stmt to build up the query
for field, value in kwargs.items():
stmt = stmt.where(getattr(self.users.c, field) == value)
return self.connection.execute(stmt)
But you did all the hard work.. It's just a matter of combining the functions you created, initializing the input variables and throwing some if statements in the mix.
def get_user(self, user_id=0, username='', role='', company=''):
if user_id:
stmt = self.users.select().where(self.users.c.user_id == user_id)
return self.connection.execute(stmt)
elif username:
stmt = self.users.select().where(self.users.c.username == username)
return self.connection.execute(stmt)
elif role and company:
stmt = self.users.select().where(self.users.c.role == role).where(self.users.c.company == company)
return self.connection.execute(stmt)
else:
print('Not adequate information given. Please enter "ID" or "USERNAME", or "ROLE"&"COMPANY"')
return
Note that user_id has been initialized to 0 so that it has a boolean of False. If a 0 id is possible, set it directly to False instead. So, since the input cannot be 'random', is there a reason you want to do it with **kwargs?
Alternatively, if the numberof combinations are too many to code, i would go a different route (SQL-injection-valnurable script incoming!) and that is the following:
def get_user(self, query):
form_query = 'SELECT user FROM {} WHERE {}'.format(table_name, query)
# now execute it and return whatever it is you want returned
You are no longer passing variables to the function but rather a string which will appended to the query and executed.
Needless to say you have to be very careful with that.
Try something like:
def get_user(self, **kwargs):
if 'user_id' in kwargs:
(...)
elif 'username' in kwargs:
(...)
elif all(a in ['role','company'] for a in kwargs):
(...)
else:
(...)

How can I create a z3c.form that has multiple schemas?

I am using the cms Plone to build a form that contains two other schemas.
With the Group Forms, I have been able to include both of the fields of the two schemas. However, they lose all of their properties such as hidden or when I use datagridfield to build the table. What I want to be able to do is have both of these forms with their fields and on a save be able to save them in the object which the link was clicked as parent -> object 1 [top of the form] -> object 2 [bottom of the form]
Here is my python code:
class QuestionPart(group.Group):
label = u'Question Part'
fields = field.Fields(IQuestionPart)
template = ViewPageTemplateFile('questionpart_templates/view.pt')
class Question(group.Group):
label = u'Question'
fields = field.Fields(IQuestion)
template = ViewPageTemplateFile('question_templates/view.pt')
class QuestionSinglePart(group.GroupForm, form.AddForm):
grok.name('register')
grok.require('zope2.View')
grok.context(ISiteRoot)
label = u"Question with Single Part"
ignoreContext = True
enable_form_tabbing = False
groups = (Question,QuestionPart)
def update(self):
super(QuestionSinglePart, self).update()
This code displays both the fields of IQuestion, IQuestionPart without the regard to things like: form.mode(contype='hidden') or DataGridField widget.
I found a way that displays the correct form with field hints.
class QuestionSinglePart(AutoExtensibleForm, form.AddForm):
grok.require('zope2.View')
grok.context(ISiteRoot)
label = u"Question"
schema = IQuestion
additionalSchemata = (IQuestionPart,)
I feel I am still a long way off. I have talked to some people. I am now trying to use a separate form and view.
So far, I am at this point with my code:
class QuestionSinglePartForm(AutoExtensibleForm, form.Form):
ignoreContext = True
autoGroups = True
template = ViewPageTemplateFile('questionsinglepart_templates/questionsinglepartform.pt')
#property
def additionalSchemata(self):
return self._additionalSchemata
def __init__(self, context, request, schema, additional=()):
self.context = context
self.request = request
if not IInterface.providedBy(schema):
raise ValueError('Schema is not interface object')
self._schema = schema
if not all(IInterface.providedBy(s) for s in additional):
raise ValueError('Additional schema is not interface')
self._additionalSchemata = additional
class QuestionSinglePartView(object):
schema = IQuestion
additional = (IQuestionPart,)
def __init__(self, context, request):
self.context = context
self.request = request
self.form = QuestionSinglePartForm(context, request, self.schema, self.additional)
def magic(self, data, errors):
pass
"""
question = Question()
question.number = data['number']
question.questionContent = data['questionContent']
questionPart = QuestionPart()
questionPart.typeOfQuestion = data['IQuestionPart.typeOfQuestion']
questionPart.explanation = data['IQuestionPart.explanation']
questionPart.fileSize = data['IQuestionPart.fileSize']
questionPart.fileType = data['IQuestionPart.fileType']
questionPart.hints = data['IQuestionPart.hints']
questionPart.table = data['IQuestionPart.table']
questionPart.contype = data['IQuestionPart.contype']
questionPart.content = data['IQuestionPart.content']
"""
def update(self, *args, **kwargs):
if self.request.get('REQUEST_METHOD') == 'POST':
data, errors = self.form.extractData()
self.magic(data, errors)
self.formdisplay = self.form.render()
def __call__(self, *args, **kwargs):
self.update(*args, **kwargs)
return self.index(*args, **kwargs)
I am struggling with the rendering of the form and that the QuestionSinglePart object does not have an index attribute.
After a couple of hours of working with some plone devs, we have figured out what was going on.
I had left out:
#property
def schema(self):
return self._schema
I needed to define an index in the view like so:
index = ViewPageTemplateFile('questionsinglepart_templates/questionsinglepart.pt')
I needed to add this to the views init:
alsoProvides(self.form, IWrappedForm)
In the update method for view I needed to call this before the formdisplay. I could also remove the data extraction and move that to the form.
def update(self, *args, **kwargs):
self.form.update(*args, **kwargs)
self.formdisplay = self.form.render()
I am still currently working to get the data to save into objects.
Only form classes that include the plone.autoform.base.AutoExtensibleForm mixin pay attention to schema form hints. Try using this as a mixin to your Group form class (and provide the schema attribute that this mixin looks for instead of fields):
from plone.autoform.base import AutoExtensibleForm
class QuestionPart(AutoExtensibleForm, group.Group):
label = u'Question Part'
schema = IQuestionPart
template = ViewPageTemplateFile('questionpart_templates/view.pt')
Here is my final code with the changes made above. There were issues with the index for the object. I needed to create a simple custom view. I forgot the property for schema on the form. I need to changed my update method for the view as well.
class QuestionSinglePartForm(AutoExtensibleForm, form.Form):
ignoreContext = True
autoGroups = False
#property
def schema(self):
return self._schema
#property
def additionalSchemata(self):
return self._additionalSchemata
def __init__(self, context, request, schema, additional=()):
self.context = context
self.request = request
if not IInterface.providedBy(schema):
raise ValueError('Schema is not interface object')
self._schema = schema
if not all(IInterface.providedBy(s) for s in additional):
raise ValueError('Additional schema is not interface')
self._additionalSchemata = additional
#button.buttonAndHandler(u'Save')
def handleSave(self, action):
data, errors = self.extractData()
if errors:
return False
obj = self.createAndAdd(data)
if obj is not None:
# mark only as finished if we get the new object
self._finishedAdd = True
IStatusMessage(self.request).addStatusMessage(_(u"Changes saved"), "info")
print data
#button.buttonAndHandler(u'Cancel')
def handleCancel(self, action):
print 'cancel'
class QuestionSinglePartView(object):
schema = IQuestion
additional = (IQuestionPart,)
index = ViewPageTemplateFile('questionsinglepart_templates/questionsinglepart.pt')
def __init__(self, context, request):
self.context = context
self.request = request
self.form = QuestionSinglePartForm(context, request, self.schema, self.additional)
alsoProvides(self.form, IWrappedForm)
def magic(self, data, errors):
pass
"""
question = Question()
question.number = data['number']
question.questionContent = data['questionContent']
questionPart = QuestionPart()
questionPart.typeOfQuestion = data['IQuestionPart.typeOfQuestion']
questionPart.explanation = data['IQuestionPart.explanation']
questionPart.fileSize = data['IQuestionPart.fileSize']
questionPart.fileType = data['IQuestionPart.fileType']
questionPart.hints = data['IQuestionPart.hints']
questionPart.table = data['IQuestionPart.table']
questionPart.contype = data['IQuestionPart.contype']
questionPart.content = data['IQuestionPart.content']
"""
def update(self, *args, **kwargs):
self.form.update(*args, **kwargs)
self.formdisplay = self.form.render()
def __call__(self, *args, **kwargs):
self.update(*args, **kwargs)
return self.index(*args, **kwargs)

What's the correct way to unit test Pyramid views with a SQLAlchemy DBSession?

When writing a Pyramid unit test suite, what is the proper or appropriate way to unit test a view that does a SQLAlchemy call. Example:
def my_view(request):
DBSession.query(DeclarativeBase).all()
Would I use Mock() and patch to override the scope of DBSession to a DummyDB class of the sorts?
You can, and I'll be blogging/speaking/exampling about this soon. This is totally new stuff. Here's a sneak peek:
import mock
from sqlalchemy.sql import ClauseElement
class MockSession(mock.MagicMock):
def __init__(self, *arg, **kw):
kw.setdefault('side_effect', self._side_effect)
super(MockSession, self).__init__(*arg, **kw)
self._lookup = {}
def _side_effect(self, *arg, **kw):
if self._mock_return_value is not mock.sentinel.DEFAULT:
return self._mock_return_value
else:
return self._generate(*arg, **kw)
def _get_key(self, arg, kw):
return tuple(self._hash(a) for a in arg) + \
tuple((k, self._hash(kw[k])) for k in sorted(kw))
def _hash(self, arg):
if isinstance(arg, ClauseElement):
expr = str(arg.compile(compile_kwargs={"literal_binds": True}))
return expr
else:
assert hash(arg)
return arg
def _generate(self, *arg, **kw):
key = self._get_key(arg, kw)
if key in self._lookup:
return self._lookup[key]
else:
self._lookup[key] = ret = MockSession()
return ret
if __name__ == '__main__':
from sqlalchemy import Column, Integer
from sqlalchemy.ext.declarative import declarative_base
Base = declarative_base()
class Foo(Base):
__tablename__ = 'foo'
id = Column(Integer, primary_key=True)
x = Column(Integer)
y = Column(Integer)
sess = MockSession()
# write out queries as they would in the code, assign return_value
sess.query(Foo.x).filter(Foo.x == 5).first.return_value = 5
sess.query(Foo.x).filter(Foo.x == 2).first.return_value = 2
sess.query(Foo).filter(Foo.x == 2).filter_by(y=5).all.return_value = [Foo(x=2, y=5)]
sess.query(Foo).filter(Foo.x == 9).all.return_value = [Foo(x=9, y=1), Foo(x=9, y=2)]
# those queries are now replayable and will return what was assigned.
print sess.query(Foo.x).filter(Foo.x == 5).first()
print sess.query(Foo.x).filter(Foo.x == 2).first()
print sess.query(Foo).filter(Foo.x == 2).filter_by(y=5).all()
print sess.query(Foo).filter(Foo.x == 9).all()
I've actually assigned this into the global ScopedSession within setup/teardown and it works amazingly:
from myapp.model import MyScopedSession
def setUp(self):
MyScopedSession.registry.set(MockSession())
# set up some queries, we can usually use scoped_session's proxying
MyScopedSession.query(User).filter_by(id=1).first.return_value = User(id=1)
def tearDown(self):
MyScopedSession.remove()
def some_test(self):
# to get at mock methods and accessors, call the scoped_session to get at
# the registry
# ... test some thing
# test a User was added to the session
self.assertEquals(
MyScopedSession().add.mock_calls,
[mock.call(User(name='someuser'))]
)

Categories

Resources