I have an app where I want the user to be able to bookmark/un-bookmark a blog, but upon un-bookmarking, I don't want to remove that bookmark record. So I have an is_bookmarked property on my Bookmark model to determine whether a bookmark is active/inactive.
In my test file, I have
def test_unbookmark_a_blog_do_assign(session):
    """Bookmark, then un-bookmark, a blog while holding the first result in a local.

    This is the variant that passes in the question; the only difference from
    the failing variant is the ``bookmark`` assignment.
    """
    blog = create_blog(session)
    bookmark = toggle_bookmark(session, blog_id=blog.id)  # kept on purpose
    assert len(blog.bookmarks) == 1
    toggle_bookmark(session, blog_id=blog.id)
    assert len(blog.bookmarks) == 0
This test passes. However, the following won't. (Only difference is I do not assign a variable for the toggle_bookmark's outcome.)
def test_unbookmark_a_blog_no_assign(session):
    """Bookmark, then un-bookmark, a blog without keeping a reference to the bookmark.

    In the question this variant fails at the second assertion.
    """
    blog = create_blog(session)
    toggle_bookmark(session, blog_id=blog.id)
    assert len(blog.bookmarks) == 1
    toggle_bookmark(session, blog_id=blog.id)
    assert len(blog.bookmarks) == 0
It fails at the second assertion, assert len(blog.bookmarks) == 0. The reason is that blog._bookmarks[0].is_bookmarked is not updated outside the toggle_bookmark function and is still True, so the bookmark still shows up in blog.bookmarks. (Definitions are attached below.)
For context, I am using classic mapping:
@dataclass
class Bookmark:
    """A blog bookmark; un-bookmarking flips ``is_bookmarked`` instead of deleting it."""

    # NOTE: this plain default leaves ``Bookmark.is_bookmarked = True`` on the
    # class itself, which (per the answer below) keeps the classical mapper from
    # instrumenting the column.
    is_bookmarked: bool = True
    blog_id: Optional[int] = None
@dataclass
class Blog:
    """A blog whose bookmark records are kept even after being un-bookmarked."""

    # Every bookmark record, active or not (mapped as the "_bookmarks" relationship).
    _bookmarks: List["Bookmark"] = field(default_factory=list)

    def add_bookmark(self, bookmark):
        """Attach ``bookmark`` to this blog."""
        self._bookmarks.append(bookmark)

    @property
    def bookmarks(self):
        """Return only the bookmarks whose ``is_bookmarked`` flag is truthy."""
        return [bookmark for bookmark in self._bookmarks if bookmark.is_bookmarked]
...
# One row per blog.
blog_table = Table(
    "blog",
    metadata,
    Column("id", Integer, primary_key=True, index=True),
)

# One row per bookmark record; rows are kept when a blog is un-bookmarked.
bookmark_table = Table(
    "bookmark",
    metadata,
    Column("id", Integer, primary_key=True, index=True),
    Column("is_bookmarked", Boolean, default=True),
    Column("blog_id", ForeignKey("blog.id"), nullable=True),
)
...
# Classical mapping; Blog._bookmarks and Bookmark.blog mirror each other
# through back_populates.
mapper(
    Blog,
    blog_table,
    properties={"_bookmarks": relationship(Bookmark, back_populates="blog")},
)
mapper(
    Bookmark,
    bookmark_table,
    properties={"blog": relationship(Blog, back_populates="_bookmarks")},
)
The toggle_bookmark function:
def toggle_bookmark(db_session, *, blog_id):
    """Create the blog's bookmark on first use, otherwise flip its flag.

    Un-bookmarking never deletes the row; it only toggles ``is_bookmarked``.

    Args:
        db_session: session used for the lookups and the commit.
        blog_id: primary key of the blog to (un)bookmark.

    Returns:
        The created or updated ``Bookmark``.
    """
    blog = db_session.query(Blog).get(blog_id)
    # Expects at most one bookmark row per blog; one_or_none() raises otherwise.
    bookmark = db_session.query(Bookmark).filter(
        Bookmark.blog_id == blog_id
    ).one_or_none()

    if bookmark is None:
        bookmark = Bookmark()
        blog.add_bookmark(bookmark)
        db_session.add(blog)
    else:
        bookmark.is_bookmarked = not bookmark.is_bookmarked
        db_session.add(bookmark)

    db_session.commit()
    return bookmark
I am really confused... My gut tells me that it has something to do with when the query gets evaluated, but I haven't managed to find any evidence to support it. Any help is appreciated. Thanks in advance!
A full example:
from dataclasses import dataclass, field
from typing import Optional, List
from sqlalchemy import (
create_engine, MetaData, Table, Column, Integer, Boolean, ForeignKey)
from sqlalchemy.orm import mapper, relationship, sessionmaker
@dataclass
class Bookmark:
    """A blog bookmark; un-bookmarking flips ``is_bookmarked`` instead of deleting it."""

    # NOTE: this plain default leaves ``Bookmark.is_bookmarked = True`` on the
    # class itself, which (per the answer below) keeps the classical mapper from
    # instrumenting the column.
    is_bookmarked: bool = True
    blog_id: Optional[int] = None
@dataclass
class Blog:
    """A blog whose bookmark records are kept even after being un-bookmarked."""

    # Every bookmark record, active or not (mapped as the "_bookmarks" relationship).
    _bookmarks: List["Bookmark"] = field(default_factory=list)

    def add_bookmark(self, bookmark):
        """Attach ``bookmark`` to this blog."""
        self._bookmarks.append(bookmark)

    @property
    def bookmarks(self):
        """Return only the bookmarks whose ``is_bookmarked`` flag is truthy."""
        return [bookmark for bookmark in self._bookmarks if bookmark.is_bookmarked]
# SQLite engine with an empty database path (in-memory); MetaData is bound to
# it, so create_all() below needs no argument.
engine = create_engine("sqlite:///")
metadata = MetaData(bind=engine)

blog_table = Table(
    "blog",
    metadata,
    Column("id", Integer, primary_key=True, index=True),
)

bookmark_table = Table(
    "bookmark",
    metadata,
    Column("id", Integer, primary_key=True, index=True),
    Column("is_bookmarked", Boolean, default=True),
    Column("blog_id", ForeignKey("blog.id"), nullable=True),
)

metadata.create_all()
# Classical mapping; Blog._bookmarks and Bookmark.blog mirror each other
# through back_populates.
mapper(
    Blog,
    blog_table,
    properties={"_bookmarks": relationship(Bookmark, back_populates="blog")},
)
mapper(
    Bookmark,
    bookmark_table,
    properties={"blog": relationship(Blog, back_populates="_bookmarks")},
)
def toggle_bookmark(db_session, *, blog_id):
    """Create the blog's bookmark on first use, otherwise flip its flag.

    Un-bookmarking never deletes the row; it only toggles ``is_bookmarked``.

    Args:
        db_session: session used for the lookups and the commit.
        blog_id: primary key of the blog to (un)bookmark.

    Returns:
        The created or updated ``Bookmark``.
    """
    blog = db_session.query(Blog).get(blog_id)
    # Expects at most one bookmark row per blog; one_or_none() raises otherwise.
    bookmark = db_session.query(Bookmark).filter(
        Bookmark.blog_id == blog_id
    ).one_or_none()

    if bookmark is None:
        bookmark = Bookmark()
        blog.add_bookmark(bookmark)
        db_session.add(blog)
    else:
        bookmark.is_bookmarked = not bookmark.is_bookmarked
        db_session.add(bookmark)

    db_session.commit()
    return bookmark
def create_blog(session):
    """Add a new, empty ``Blog`` to ``session``, commit, and return it."""
    blog = Blog()
    session.add(blog)
    session.commit()
    return blog
def test_unbookmark_a_blog_do_assign(session):
    """Bookmark, then un-bookmark, a blog while holding the first result in a local.

    This is the variant that passes in the question; the only difference from
    the failing variant is the ``bookmark`` assignment.
    """
    blog = create_blog(session)
    bookmark = toggle_bookmark(session, blog_id=blog.id)  # kept on purpose
    assert len(blog.bookmarks) == 1
    toggle_bookmark(session, blog_id=blog.id)
    assert len(blog.bookmarks) == 0
def test_unbookmark_a_blog_no_assign(session):
    """Bookmark, then un-bookmark, a blog without keeping a reference to the bookmark.

    In the question this variant fails at the second assertion.
    """
    blog = create_blog(session)
    toggle_bookmark(session, blog_id=blog.id)
    assert len(blog.bookmarks) == 1
    toggle_bookmark(session, blog_id=blog.id)
    assert len(blog.bookmarks) == 0
# Run both tests, each with a fresh Session. sessionmaker() has no bind here;
# the engine is presumably reached through MetaData(bind=engine) above.
Session = sessionmaker()
test_unbookmark_a_blog_do_assign(Session())
test_unbookmark_a_blog_no_assign(Session())
The core problem is this:
class Bookmark:
    is_bookmarked: bool = True  # <-- This here
Classical mapping does not install instrumentation over an existing class attribute, so changes to an instance's is_bookmarked are not persisted. Without the assignment, nothing in the test keeps the Bookmark instance alive. The Session's identity map holds only weak references to unmodified objects, so the instance is discarded. The next query then reads its state from the database, where it still holds the default value True. With the assignment, the instance stays referenced in the test's scope and therefore stays in the Session, and later queries return that same, modified instance.
You would run into the same issue with SQLAlchemy and dataclasses if you used field(default=...), because the default still ends up as a class attribute:
>>> from dataclasses import dataclass, field
>>> @dataclass
... class C:
...     f: bool = field(default=True)
...
>>> C.f
True
A solution to get over the situation is to use a field() with default_factory= for is_bookmarked as well:
@dataclass
class Bookmark:
    # default_factory= (with no default=) leaves no class attribute behind, so
    # the classical mapper is free to install its instrumentation here.
    is_bookmarked: bool = field(default_factory=lambda: True)
    ...
This works because, in recent enough Python versions, a field() that has default_factory= but no default= is not left on the class as an attribute, so the mapper can install its instrumentation.
Related
I am struggling to define methods in SQLAlchemy to retrieve related records via an intermediary table.
Consider the following schema:
Users can create multiple posts, each post belongs to 1 user
Each post can have multiple comments on it, with each comment belonging to 1 post
What I want is to be able to, for a given user instance, retrieve all of the comments from all of their posts.
I have set this up as follows:
from sqlalchemy import ForeignKey
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
class Base(DeclarativeBase):
    """Declarative base; every mapped subclass gets an integer ``id`` primary key."""

    id: Mapped[int] = mapped_column(primary_key=True)
# define model classes
class User(Base):
    """A user who can write many posts."""

    __tablename__ = "users"

    name: Mapped[str] = mapped_column()
    posts: Mapped[list["Post"]] = relationship(back_populates="user")

    def __repr__(self) -> str:
        # type(self) instead of __class__, which is fixed to the defining class.
        return f"(<{type(self).__name__}> name: {self.name})"
class Post(Base):
    """A post belonging to one user, with many comments."""

    __tablename__ = "posts"

    title: Mapped[str] = mapped_column()
    user_id: Mapped[int] = mapped_column(ForeignKey("users.id"))
    user: Mapped["User"] = relationship(back_populates="posts")
    comments: Mapped[list["Comment"]] = relationship(back_populates="post")

    def __repr__(self) -> str:
        # type(self) instead of __class__, which is fixed to the defining class.
        return f"(<{type(self).__name__}> title: {self.title})"
class Comment(Base):
    """A comment belonging to one post."""

    __tablename__ = "comments"

    body: Mapped[str] = mapped_column()
    post_id: Mapped[int] = mapped_column(ForeignKey("posts.id"))
    post: Mapped["Post"] = relationship(back_populates="comments")

    def __repr__(self) -> str:
        # type(self) instead of __class__, which is fixed to the defining class.
        return f"(<{type(self).__name__}> body: {self.body})"
If I create a few instances of these models, you can see how things are related:
# create instances (relationships are wired up purely in Python here)
user = User(name="greta")
post_1 = Post(title="First post", user=user)
post_2 = Post(title="Second post", user=user)
comment_1 = Comment(body="yeah wotever", post=post_1)
comment_2 = Comment(body="lol good one", post=post_1)
comment_3 = Comment(body="lmfao", post=post_2)

# show all posts, and their comments
print(user)
for post in user.posts:
    print(f" └── {post}")
    for comment in post.comments:
        print(f" └── {comment}")
(<User> name: greta)
└── (<Post> title: First post)
└── (<Comment> body: yeah wotever)
└── (<Comment> body: lol good one)
└── (<Post> title: Second post)
└── (<Comment> body: lmfao)
I am unsure of how to use relationship() to define a method all_comments() in the User class, which would return a list of all of the comments across all of a user instance's posts.
Can anyone point me in the right direction?
Using posts as the secondary table, you can pass primaryjoin and secondaryjoin to relationship() to get what you want.
This way you don't need an all_comments() method: user.comments returns the comments directly, and in the other direction comment.user gives you the user (see the edit below).
You probably want to tweak the join conditions, from my rudimentary testing this seems to get where you want to be.
I have created two users with different sets of posts and comments so you can see the difference.
from sqlalchemy import ForeignKey, create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship, Session
class Base(DeclarativeBase):
    """Declarative base; every mapped subclass gets an integer ``id`` primary key."""

    id: Mapped[int] = mapped_column(primary_key=True)
class User(Base):
    """A user with posts, plus a read-only view of every comment on those posts."""

    __tablename__ = "users"

    name: Mapped[str] = mapped_column()
    posts: Mapped[list["Post"]] = relationship(back_populates="user")
    # users -> posts (secondary) -> comments. primaryjoin links User to the
    # secondary table, secondaryjoin links the secondary table to Comment.
    comments: Mapped[list["Comment"]] = relationship(
        back_populates="user",
        secondary="posts",
        primaryjoin="User.id == Post.user_id",
        secondaryjoin="Comment.post_id == Post.id",
        viewonly=True,
    )

    def __repr__(self) -> str:
        # type(self) instead of __class__, which is fixed to the defining class.
        return f"(<{type(self).__name__}> name: {self.name})"
class Post(Base):
    """A post belonging to one user, with many comments."""

    __tablename__ = "posts"

    title: Mapped[str] = mapped_column()
    user_id: Mapped[int] = mapped_column(ForeignKey("users.id"))
    user: Mapped["User"] = relationship(back_populates="posts")
    comments: Mapped[list["Comment"]] = relationship(back_populates="post")

    def __repr__(self) -> str:
        # type(self) instead of __class__, which is fixed to the defining class.
        return f"(<{type(self).__name__}> title: {self.title})"
class Comment(Base):
    """A comment on one post, with a read-only link to the post's author."""

    __tablename__ = "comments"

    body: Mapped[str] = mapped_column()
    post_id: Mapped[int] = mapped_column(ForeignKey("posts.id"))
    post: Mapped["Post"] = relationship(back_populates="comments")
    # comments -> posts (secondary) -> users. Seen from Comment, primaryjoin must
    # link Comment to the secondary table and secondaryjoin must link the
    # secondary table to User, i.e. the reverse of User.comments. Reusing
    # User.comments' conditions unswapped is the likely reason a comment
    # appeared to have several users.
    user: Mapped["User"] = relationship(
        back_populates="comments",
        secondary="posts",
        primaryjoin="Comment.post_id == Post.id",
        secondaryjoin="User.id == Post.user_id",
        viewonly=True,
        uselist=False,
    )

    def __repr__(self) -> str:
        # type(self) instead of __class__, which is fixed to the defining class.
        return f"(<{type(self).__name__}> body: {self.body})"
# File-backed SQLite: rows accumulate if the script is run more than once.
engine = create_engine("sqlite:///temp.db")
Base.metadata.create_all(engine)

# Insert two users with different posts/comments in a single transaction.
with Session(engine) as session, session.begin():
    user = User(name="greta")
    post_1 = Post(title="First post", user=user)
    post_2 = Post(title="Second post", user=user)
    comment_1 = Comment(body="yeah wotever", post=post_1)
    comment_2 = Comment(body="lol good one", post=post_1)
    comment_3 = Comment(body="lmfao", post=post_2)
    session.add_all((user, post_1, post_2, comment_1, comment_2, comment_3))

    user = User(name="not greta")
    post_1 = Post(title="Third post", user=user)
    post_2 = Post(title="Fourth post", user=user)
    comment_1 = Comment(body="wotever", post=post_1)
    comment_2 = Comment(body="good one", post=post_1)
    session.add_all((user, post_1, post_2, comment_1, comment_2))

# Read back in a fresh session; user.comments is loaded through the posts table.
with Session(engine) as session:
    statement = select(User)
    for user in session.scalars(statement):
        print(user, user.comments)
Output
(<User> name: greta) [(<Comment> body: yeah wotever), (<Comment> body: lol good one), (<Comment> body: lmfao)]
(<User> name: not greta) [(<Comment> body: wotever), (<Comment> body: good one)]
Edit: The reverse relation "get user from a comment" seems to be bugged in this implementation, one comment has more than one user, I am not sure where I went wrong, but if all you want is the relation "get all comments for a user" then this works.
I have these two tables named User and UserRole.
class UserRoleType(str, enum.Enum):
    """Role names; mixing in str makes each member compare equal to its value."""

    admin = 'admin'
    client = 'client'
class UserRole(SQLModel, table=True):
    """A role row that many users can reference."""

    __tablename__ = 'user_role'

    id: int | None = Field(default=None, primary_key=True)
    type: UserRoleType = Field(
        default=UserRoleType.client,
        sa_column=Column(Enum(UserRoleType)),
    )
    write_access: bool = Field(default=False)
    read_access: bool = Field(default=False)

    users: List['User'] = Relationship(back_populates='user_role')
class User(SQLModel, table=True):
    """An application user referencing a UserRole through user_role_id."""

    id: int | None = Field(default=None, primary_key=True)
    username: str = Field(..., index=True)
    user_role_id: int = Field(..., foreign_key='user_role.id')
    user_role: 'UserRole' = Relationship(back_populates='users')
I can easily insert them into the DB with:
async with get_session() as session:
role = UserRole(description=UserRoleType.client)
session.add(role)
await session.commit()
user = User( username='test', user_role_id=role.id)
session.add(user)
await session.commit()
await session.refresh(user)
And access the committed data with:
results = await session.execute(select(User).where(User.id == 1)).one()
Output:
(User(user_role_id=1, username='test', id=1),)
Notice that there's a user_role_id, but where's the user_role object?
In fact, if I try to access it, it raises:
*** AttributeError: Could not locate column in row for column 'user_role'
I also tried to pass the role instead of the user_role_id at the insertion of the User:
user = User(username='test', user_role=role)
But I got:
sqlalchemy.exc.InterfaceError: (sqlite3.InterfaceError) Error binding parameter 2 - probably unsupported type.
A few things first.
You did not include your import statements, so I will have to guess a few things.
You probably want the User.user_role_id and User.user_role fields to be "pydantically" optional. This allows you to create user instances without passing the role to the constructor, giving you the option to do so after initialization or for example by appending User objects to the UserRole.users list instead. To enforce that a user must have a role on the database level, you simply define nullable=False on the User.user_role_id field. That way, if you try to commit to the DB without having defined a user role for a user in any of the possible ways, you will get an error.
In your database insertion code you write role = UserRole(description=UserRoleType.client). I assume the description is from older code and you meant to write role = UserRole(type=UserRoleType.client).
You probably want your UserRole.type to be not nullable on the database side. You can do so by passing nullable=False to the Column constructor (not the Field constructor).
I will simplify a little by using blocking code (non-async) and a SQLite database.
This should work:
from enum import Enum as EnumPy
from sqlalchemy.sql.schema import Column
from sqlalchemy.sql.sqltypes import Enum as EnumSQL
from sqlmodel import Field, Relationship, Session, SQLModel, create_engine
class UserRoleType(str, EnumPy):
    """Role names; mixing in str makes each member compare equal to its value."""

    admin = 'admin'
    client = 'client'
class UserRole(SQLModel, table=True):
    """A role row that many users can reference."""

    __tablename__ = 'user_role'

    id: int | None = Field(default=None, primary_key=True)
    # NOT NULL is set on the Column passed as sa_column, not on Field().
    type: UserRoleType = Field(
        default=UserRoleType.client,
        sa_column=Column(EnumSQL(UserRoleType), nullable=False),
    )
    write_access: bool = Field(default=False)
    read_access: bool = Field(default=False)

    users: list['User'] = Relationship(back_populates='user_role')
class User(SQLModel, table=True):
    """An application user; the database requires every user to have a role."""

    __tablename__ = 'user'

    id: int | None = Field(default=None, primary_key=True)
    username: str = Field(..., index=True)
    # Optional on the Python side (the role may be attached after construction),
    # but NOT NULL in the database.
    user_role_id: int | None = Field(foreign_key='user_role.id', default=None, nullable=False)
    user_role: UserRole | None = Relationship(back_populates='users')
def test() -> None:
    """Create a user linked to a new role via the relationship and check both sides."""
    # Initialize database & session (the SQLite file is recreated on every run):
    sqlite_file_name = 'user_role.db'
    sqlite_url = f'sqlite:///{sqlite_file_name}'
    engine = create_engine(sqlite_url)
    SQLModel.metadata.drop_all(engine)
    SQLModel.metadata.create_all(engine)

    # The session is closed when the block ends; checks stay inside it so lazy
    # loads still have a session.
    with Session(engine) as session:
        # Create the test objects; passing the role object lets the ORM fill
        # in user_role_id on flush.
        role = UserRole(type=UserRoleType.client)
        user = User(username='test', user_role=role)
        session.add(user)
        session.commit()
        session.refresh(user)

        # Do some checks:
        assert isinstance(user.user_role.type, EnumPy)
        assert user.user_role_id == role.id and isinstance(role.id, int)
        assert role.users == [user]
if __name__ == '__main__':
    test()
PS: I know the question was posted a while ago, but maybe this still helps or helps someone else.
I'm trying to relate two tables that have multiple 'secondary' tables. Instead of the declarative syntax, it's necessary that I use the classical. Here is a simplified schema:
class Apple:
    """Plain class, mapped classically onto the 'Apple' table."""

    def __init__(self, id=None, name=None):
        ...  # body elided in the original post
class Recipe:
    """Plain class, mapped classically onto the 'Recipe' table."""

    def __init__(self, id=None, appleId=None, name=None):
        ...  # body elided in the original post
class Blog:
    """Plain class, mapped classically onto the 'Blog' table."""

    def __init__(self, id=None, name=None, recipeId=None, bloggerId=None):
        ...  # body elided in the original post
class Blogger:
    """Plain class, mapped classically onto the 'Blogger' table."""

    def __init__(self, name):  # the original was missing this colon
        ...  # body elided in the original post
appleTable = Table(
    'Apple', metadata,
    Column('id', Integer, primary_key=True),
    Column('name', String(256)),
)
recipeTable = Table(
    'Recipe', metadata,
    Column('id', Integer, primary_key=True),
    Column('name', String(256)),
    Column('appleId', Integer, ForeignKey('Apple.id')),
)
blogTable = Table(
    'Blog', metadata,
    Column('id', Integer, primary_key=True),
    Column('name', String(256)),
    Column('recipeId', Integer, ForeignKey('Recipe.id')),
    Column('bloggerId', Integer, ForeignKey('Blogger.id')),
)
bloggerTable = Table(
    'Blogger', metadata,
    Column('id', Integer, primary_key=True),
    Column('name', String(256)),
)

# call mapper on all tables/classes
# ... #

# Relate 'Apple' to 'Blogger' using 'Recipe' and 'Blog' as intermediates
# (the `...` arguments are placeholders for what the question is asking).
Apple.appleBloggers = relationship(
    Blogger, secondary=..., primaryjoin=..., secondaryjoin=...
)
What relationship would I need to place into the appleBloggers attribute of Apple in order to retrieve all bloggers who've blogged about apple recipes?
Edit: Solution
An alternative to @univerio's solution is posted below. The difference is that it uses the mapped classes instead of the table variables. I've also added a viewonly parameter, which prevents writes to the attribute.
mapper(Apple, appleTable, properties={
    "appleBloggers": relationship(
        Blogger,
        # Recipe JOIN Blog is the intermediate selectable between Apple and Blogger.
        secondary=join(Recipe, Blog, Recipe.id == Blog.recipeId),
        # Given as a callable, so it is evaluated when the mappers are configured.
        secondaryjoin=lambda: and_(Apple.id == Recipe.appleId, Blogger.id == Blog.bloggerId),
        # Read-only relationship: writes to the attribute are not persisted.
        viewonly=True,
    ),
})
Original Attempt
Here is what I've tried:
# Declarative-style assignment of a relationship onto a classically mapped
# class; the answer below explains why this fails.
Apple.appleBloggers = relationship(
    Blogger,
    secondary=join(Recipe, Blog, Recipe.id == Blog.recipeId),
    primaryjoin=Apple.id == Recipe.appleId,
    secondaryjoin=Blog.bloggerId == Blogger.id,
)
But whenever I do the following:
# Reproduction from the question: persist a new Apple, then read the relationship.
# Running this sequence produces the traceback shown below.
apple = Apple(name="RedDelicious")
session.add(apple)
session.commit()
print(apple.appleBloggers)
I get the error:
File ".../.pyenv/versions/2.7.10/lib/python2.7/site-packages/sqlalchemy/orm/relationships.py", line 1425, in __str__
return str(self.parent.class_.__name__) + "." + self.key
File ".../.pyenv/versions/2.7.10/lib/python2.7/site-packages/sqlalchemy/util/langhelpers.py", line 840, in __getattr__
return self._fallback_getattr(key)
File ".../.pyenv/versions/2.7.10/lib/python2.7/site-packages/sqlalchemy/util/langhelpers.py", line 818, in _fallback_getattr
raise AttributeError(key)
AttributeError: parent
You're mixing declarative and classical mappings. Assigning a relationship like that only works for declarative. The proper way to do this in classical mapping is:
# Classical mapping: the relationship goes into mapper(properties=...) and is
# expressed with the Table objects' columns.
mapper(Apple, appleTable, properties={
    "appleBloggers": relationship(
        Blogger,
        # Recipe JOIN Blog is the intermediate selectable between Apple and Blogger.
        secondary=recipeTable.join(blogTable, recipeTable.c.id == blogTable.c.recipeId),
        primaryjoin=appleTable.c.id == recipeTable.c.appleId,
        secondaryjoin=blogTable.c.bloggerId == bloggerTable.c.id,
    ),
})
Alternative Solution (with mapped objects):
mapper(Apple, appleTable, properties={
    "appleBloggers": relationship(
        Blogger,
        # Recipe JOIN Blog is the intermediate selectable between Apple and Blogger.
        secondary=join(Recipe, Blog, Recipe.id == Blog.recipeId),
        # Given as a callable, so it is evaluated when the mappers are configured.
        secondaryjoin=lambda: and_(Apple.id == Recipe.appleId, Blogger.id == Blog.bloggerId),
        # Read-only relationship: writes to the attribute are not persisted.
        viewonly=True,
    ),
})
I've created some base classes for an application that stores data using SQLAlchemy. One of the base classes (Content) is polymorphic and has some general fields such as id, title, description, timestamps, etc. Subclasses of this class are supposed to add additional fields, which are stored in a separate table. I've created a standalone code sample that illustrates the concept better. The example contains the base classes, some subclasses and some bootstrap code to create a SQLite database. The easiest way to get the example running is to paste the code into 'example.py', create a virtualenv, install SQLAlchemy into that virtualenv and use its interpreter to run the example. The example contains some troublesome code that is commented out; with that code commented out, the example should run without errors (at least it does here).
Uncommenting that code makes the example fail, and I'm not quite sure how to fix it. Any help is very welcome!
Example overview:
It has some base classes (Base and Content).
It has a Task class which extends Content.
A Task may have subtasks, positional ordering should persist.
It has a Project class (commented) which extends Content.
Projects have a due_date and milestones (which is a list of Tasks)
It has a Worklist class (commented) which extends Content.
Worklists belong to an 'employee' and have tasks.
What I'm trying to achieve is having Task work as a standalone class, while other classes (such as Project and Worklist) may also have Tasks. I don't want to end up with several task-related tables; instead, I want to use Content for this concept and attach Tasks in this 'generic' way.
Example code:
from datetime import datetime
from datetime import timedelta
from sqlalchemy import Column
from sqlalchemy import Integer
from sqlalchemy import Boolean
from sqlalchemy import String
from sqlalchemy import DateTime
from sqlalchemy import Date
from sqlalchemy import Unicode
from sqlalchemy import UnicodeText
from sqlalchemy import ForeignKey
from sqlalchemy import MetaData
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.ext.orderinglist import ordering_list
from sqlalchemy.orm import Session
from sqlalchemy.orm import scoped_session
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import relationship
from sqlalchemy.orm import backref
from sqlalchemy.util import classproperty
class Base(object):
    """Mixin for declarative_base(): table naming plus add/update/delete helpers."""

    @declared_attr
    def __tablename__(cls):
        # Table name is the lower-cased class name, e.g. Task -> "task".
        return cls.__name__.lower()

    @property
    def columns(self):
        """Keys of the columns mapped on this instance's mapper."""
        return self.__mapper__.columns.keys()

    def add(self, **data):
        """Apply ``data`` via update(), add self to db_session and flush."""
        self.update(**data)
        db_session.add(self)
        db_session.flush()

    def delete(self):
        """Delete self through db_session and flush."""
        db_session.delete(self)
        db_session.flush()

    def update(self, **data):
        """
        Iterate over all columns and set values from data.

        Keys that are not mapped columns, and values that are None, are ignored.
        """
        for attr in self.columns:
            value = data.get(attr)
            if value is not None:
                setattr(self, attr, value)
# Module-level SQLAlchemy setup. Order matters: `Base` is rebound below from
# the mixin class above to the declarative base built from it.
# File-backed SQLite database; echo=True logs the emitted SQL.
engine = create_engine('sqlite:///test.db', echo=True)
metadata = MetaData()
# Session registry used by the Base helpers (add/delete) and query_property().
db_session = scoped_session(sessionmaker(bind=engine))
Base = declarative_base(cls=Base)
# Share this MetaData so main() can create every model's table from it.
Base.metadata = metadata
# Gives every model a `Model.query` attribute backed by db_session.
Base.query = db_session.query_property()
class Content(Base):
    """
    Base class for all content. Includes basic features such as
    ownership and timestamps for modification and creation.
    """

    @classproperty
    def __mapper_args__(cls):
        # Polymorphic on ``type``; each subclass's identity is its lower-cased
        # class name, and with_polymorphic='*' loads all subclass columns.
        return dict(
            polymorphic_on='type',
            polymorphic_identity=cls.__name__.lower(),
            with_polymorphic='*')

    id = Column(Integer(), primary_key=True)
    type = Column(String(30), nullable=False)
    owner = Column(Unicode(128))
    title = Column(Unicode(128))
    description = Column(UnicodeText())
    creation_date = Column(DateTime(), nullable=False, default=datetime.utcnow)
    modification_date = Column(DateTime(), nullable=False, default=datetime.utcnow)

    def __init__(self, **data):
        # Every new instance is added to db_session and flushed right away.
        self.add(**data)

    def update(self, touch=True, **data):
        """
        Iterate over all columns and set values from data.

        :param touch: when true, also set ``modification_date`` to now unless
            ``data`` contains that key.
        :param data: column name -> value; None values are skipped.
        :return: None
        """
        super(Content, self).update(**data)
        if touch and 'modification_date' not in data:
            self.modification_date = datetime.utcnow()

    def __eq__(self, other):
        # Same row means equal. Two distinct instances that have no id yet are
        # not equal to each other (previously None == None made them equal).
        if self is other:
            return True
        return (isinstance(other, Content)
                and self.id is not None
                and self.id == other.id)

    def __hash__(self):
        # Defining __eq__ alone sets __hash__ to None on Python 3; hash on the
        # same key __eq__ compares so equal objects hash equally.
        return hash(self.id)
def get_content(id):
    """Return the Content (or subclass) instance with primary key ``id``, or None."""
    return Content.query.get(id)
class Task(Content):
    """A to-do item that can own an ordered list of sub-tasks."""

    id = Column(Integer, ForeignKey(Content.id), primary_key=True)
    # content_id = Column(Integer, ForeignKey(Content.id), nullable=True)
    done = Column(Boolean, default=False)
    position = Column(Integer, default=0)
    parent_id = Column(Integer, ForeignKey('task.id'), nullable=True)

    # Self-referential one-to-many joined on parent_id; ordering_list keeps
    # ``position`` in sync with list order, and order_by reloads in that order.
    tasks = relationship(
        'Task',
        cascade='all, delete, delete-orphan',
        backref=backref('parent', remote_side=id),
        foreign_keys='Task.parent_id',
        order_by=position,
        collection_class=ordering_list('position', reorder_on_append=True)
    )
def default_due_date():
    """Return the current naive UTC datetime plus 60 days."""
    return datetime.utcnow() + timedelta(days=60)
# class Project(Content):
#
# id = Column(Integer, ForeignKey(Content.id), primary_key=True)
# due_date = Column(Date, default=default_due_date)
#
# milestones = relationship(
# 'Task',
# cascade='all, delete, delete-orphan',
# backref=backref('content_parent', remote_side=id),
# foreign_keys='Task.content_id',
# collection_class=ordering_list('position', reorder_on_append=True)
# )
#
#
# class Worklist(Content):
#
# id = Column(Integer, ForeignKey(Content.id), primary_key=True)
# employee = Column(Unicode(128), nullable=False)
#
# tasks = relationship(
# 'Task',
# cascade='all, delete, delete-orphan',
# backref=backref('content_parent', remote_side=id),
# foreign_keys='Task.content_id',
# collection_class=ordering_list('position', reorder_on_append=True)
# )
def main():
    """Create the schema and exercise Task creation, sub-tasks and reordering.

    The commented-out section at the end is the Project/Worklist scenario that,
    per the question, fails when uncommented.
    """
    db_session.registry.clear()
    db_session.configure(bind=engine)
    metadata.bind = engine
    metadata.create_all(engine)

    # Test basic operation
    task = Task(title=u'Buy milk')
    task = get_content(task.id)
    # assert Content attributes inherited
    assert task.title == u'Buy milk'
    assert task.done == False

    # add subtasks
    task.tasks = [
        Task(title=u'Remember to check expiration date'),
        Task(title=u'Check bottle is not leaking')
    ]
    # assert that subtasks is added and correctly ordered
    task = get_content(task.id)
    assert len(task.tasks) == 2
    assert [(x.position, x.title) for x in task.tasks] == \
        [(0, u'Remember to check expiration date'),
         (1, u'Check bottle is not leaking')]

    # reorder subtasks
    task.tasks.insert(0, task.tasks.pop(1))
    task = get_content(task.id)
    assert len(task.tasks) == 2
    assert [(x.position, x.title) for x in task.tasks] == \
        [(0, u'Check bottle is not leaking'),
         (1, u'Remember to check expiration date')]

    # # Test Project implementation
    # project = Project(title=u'My project')
    # milestone1 = Task(title=u'Milestone #1', description=u'First milestone')
    # milestone2 = Task(title=u'Milestone #2', description=u'Second milestone')
    # milestone1.tasks = [Task(title=u'Subtask for Milestone #1'), ]
    # milestone2.tasks = [Task(title=u'Subtask #1 for Milestone #2'),
    # Task(title=u'Subtask #2 for Milestone #2')]
    # project.milestones = [milestone1, milestone2]
    # project = get_content(project.id)
    # assert project.title == u'My project'
    # assert len(project.milestones) == 2
    # assert [(x.position, x.title) for x in project.milestones] == \
    # [(0, u'Milestone #1'), (1, u'Milestone #2')]
    # assert len(Task.query.all()) == 8
    # assert isinstance(milestone1.content_parent, Project) == True
    #
    # # Test Worklist implementation
    # worklist = Worklist(title=u'My worklist', employee=u'Torkel Lyng')
    # worklist.tasks = [
    # Task(title=u'Ask stackoverflow for help'),
    # Task(title=u'Learn SQLAlchemy')
    # ]
    # worklist = get_content(worklist.id)
    # assert worklist.title == u'My worklist'
    # assert worklist.employee == u'Torkel Lyng'
    # assert len(worklist.tasks) == 2
    # assert len(Task.query.all()) == 10
    # assert isinstance(worklist.tasks[0].content_parent, Worklist) == True
if __name__ == '__main__':
    main()
I'm sorry for this long example; I wanted to supply something that works standalone. Any help, comments on the design, or suggestions are greatly appreciated.
I refactored the example a bit and got it somewhat working. Instead of defining the additional ForeignKey on Task (content_id), I added it to the Content class as container_id.
from datetime import datetime
from datetime import timedelta
from sqlalchemy import Column
from sqlalchemy import Integer
from sqlalchemy import Boolean
from sqlalchemy import String
from sqlalchemy import DateTime
from sqlalchemy import Date
from sqlalchemy import Unicode
from sqlalchemy import UnicodeText
from sqlalchemy import ForeignKey
from sqlalchemy import MetaData
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.ext.orderinglist import ordering_list
from sqlalchemy.orm import Session
from sqlalchemy.orm import scoped_session
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import relationship
from sqlalchemy.orm import backref
from sqlalchemy.util import classproperty
class Base(object):
    """Mixin for declarative_base(): table naming plus add/update/delete helpers."""

    @declared_attr
    def __tablename__(cls):
        # Table name is the lower-cased class name, e.g. Task -> "task".
        return cls.__name__.lower()

    @property
    def columns(self):
        """Keys of the columns mapped on this instance's mapper."""
        return self.__mapper__.columns.keys()

    def add(self, **data):
        """Apply ``data`` via update(), add self to db_session and flush."""
        self.update(**data)
        db_session.add(self)
        db_session.flush()

    def delete(self):
        """Delete self through db_session and flush."""
        db_session.delete(self)
        db_session.flush()

    def update(self, **data):
        """
        Iterate over all columns and set values from data.

        Keys that are not mapped columns, and values that are None, are ignored.
        """
        for attr in self.columns:
            value = data.get(attr)
            if value is not None:
                setattr(self, attr, value)
# Module-level SQLAlchemy setup. Order matters: `Base` is rebound below from
# the mixin class above to the declarative base built from it.
# File-backed SQLite database; echo=True logs the emitted SQL.
engine = create_engine('sqlite:///test.db', echo=True)
metadata = MetaData()
# Session registry used by the Base helpers (add/delete) and query_property().
db_session = scoped_session(sessionmaker(bind=engine))
Base = declarative_base(cls=Base)
# Share this MetaData so main() can create every model's table from it.
Base.metadata = metadata
# Gives every model a `Model.query` attribute backed by db_session.
Base.query = db_session.query_property()
class Content(Base):
    """
    Base class for all content. Includes basic features such as
    ownership and timestamps for modification and creation.
    """

    @classproperty
    def __mapper_args__(cls):
        # Polymorphic on ``type``; each subclass's identity is its lower-cased
        # class name, and with_polymorphic='*' loads all subclass columns.
        return dict(
            polymorphic_on='type',
            polymorphic_identity=cls.__name__.lower(),
            with_polymorphic='*')

    id = Column(Integer(), primary_key=True)
    # Self-referential link to the Content holding this item; Project.milestones
    # and Worklist.tasks join on it via foreign_keys='Task.container_id'.
    container_id = Column(Integer(), ForeignKey('content.id'), nullable=True)
    # container = relationship('Content', foreign_keys=[container_id], uselist=False)
    type = Column(String(30), nullable=False)
    owner = Column(Unicode(128))
    title = Column(Unicode(128))
    description = Column(UnicodeText())
    creation_date = Column(DateTime(), nullable=False, default=datetime.utcnow)
    modification_date = Column(DateTime(), nullable=False, default=datetime.utcnow)

    def __init__(self, **data):
        # Every new instance is added to db_session and flushed right away.
        self.add(**data)

    @property
    def container(self):
        """The Content referenced by ``container_id``, or None when unset."""
        # Compare with None so a (legal) id of 0 is not treated as "no container".
        if self.container_id is not None:
            return get_content(self.container_id)
        return None

    def update(self, touch=True, **data):
        """
        Iterate over all columns and set values from data.

        :param touch: when true, also set ``modification_date`` to now unless
            ``data`` contains that key.
        :param data: column name -> value; None values are skipped.
        :return: None
        """
        super(Content, self).update(**data)
        if touch and 'modification_date' not in data:
            self.modification_date = datetime.utcnow()

    def __eq__(self, other):
        # Same row means equal. Two distinct instances that have no id yet are
        # not equal to each other (previously None == None made them equal).
        if self is other:
            return True
        return (isinstance(other, Content)
                and self.id is not None
                and self.id == other.id)

    def __hash__(self):
        # Defining __eq__ alone sets __hash__ to None on Python 3; hash on the
        # same key __eq__ compares so equal objects hash equally.
        return hash(self.id)

    def __repr__(self):
        return '<{0} "{1}">'.format(self.__class__.__name__, self.title)
def get_content(id):
    """Return the Content (or subclass) instance with primary key ``id``, or None."""
    return Content.query.get(id)
class Task(Content):
    """A to-do item that can own an ordered list of sub-tasks."""

    id = Column(Integer, ForeignKey(Content.id), primary_key=True)
    done = Column(Boolean, default=False)
    position = Column(Integer, default=0)
    parent_id = Column(Integer, ForeignKey('task.id'), nullable=True)

    # Self-referential one-to-many joined on parent_id; ordering_list keeps
    # ``position`` in sync with list order, and order_by reloads in that order.
    tasks = relationship(
        'Task',
        cascade='all, delete, delete-orphan',
        backref=backref('parent', remote_side=id),
        foreign_keys='Task.parent_id',
        order_by=position,
        collection_class=ordering_list('position', reorder_on_append=True)
    )
def default_due_date():
    """Return the current naive UTC datetime plus 60 days."""
    return datetime.utcnow() + timedelta(days=60)
class Project(Content):
    """Content with a due date and an ordered list of milestone Tasks."""

    id = Column(Integer, ForeignKey(Content.id), primary_key=True)
    due_date = Column(Date, default=default_due_date)

    # Milestones are Tasks whose container_id points at this project. ordering_list
    # only writes ``position``; order_by is what makes reloads come back in that
    # order (as Task.tasks already does).
    milestones = relationship(
        'Task',
        cascade='all, delete, delete-orphan',
        foreign_keys='Task.container_id',
        order_by='Task.position',
        collection_class=ordering_list('position', reorder_on_append=True)
    )
class Worklist(Content):
    """Content owned by an employee, holding an ordered list of Tasks."""

    id = Column(Integer, ForeignKey(Content.id), primary_key=True)
    employee = Column(Unicode(128), nullable=False)

    # Tasks whose container_id points at this worklist. ordering_list only writes
    # ``position``; order_by is what makes reloads come back in that order.
    tasks = relationship(
        'Task',
        cascade='all, delete, delete-orphan',
        foreign_keys='Task.container_id',
        order_by='Task.position',
        collection_class=ordering_list('position', reorder_on_append=True)
    )
def main():
    """Smoke-test the Content/Task/Project/Worklist models end to end.

    Creates the schema, then covers the following:
      * task nesting and reordering;
      * projects with milestones;
      * worklists;
      * the ``container`` lookup;
      * deleting everything again.
    Failures surface as AssertionError.

    Relies on module-level ``db_session``, ``engine`` and ``metadata``
    defined elsewhere in this file.
    """
    # Reset the scoped session and bind it (and the metadata) to the engine.
    db_session.registry.clear()
    db_session.configure(bind=engine)
    metadata.bind = engine
    metadata.create_all(engine)
    # Test basic operation
    # NOTE(review): the Task is never explicitly added to db_session here.
    # task.id is only populated if Content.__init__ -> self.add() persists
    # the object (add() is defined outside this view). Confirm.
    task = Task(title=u'Buy milk')
    task = get_content(task.id)
    # assert Content attributes inherited
    assert task.title == u'Buy milk'
    assert task.done == False
    # add subtasks
    task.tasks = [
        Task(title=u'Remember to check expiration date'),
        Task(title=u'Check bottle is not leaking')
    ]
    # assert that subtasks is added and correctly ordered
    # (positions 0 and 1 come from the ordering_list on Task.tasks)
    task = get_content(task.id)
    assert len(task.tasks) == 2
    assert [(x.position, x.title) for x in task.tasks] == \
        [(0, u'Remember to check expiration date'),
         (1, u'Check bottle is not leaking')]
    # reorder subtasks
    # moving the second subtask to the front must renumber both positions
    task.tasks.insert(0, task.tasks.pop(1))
    task = get_content(task.id)
    assert len(task.tasks) == 2
    assert [(x.position, x.title) for x in task.tasks] == \
        [(0, u'Check bottle is not leaking'),
         (1, u'Remember to check expiration date')]
    # Test Project implementation
    project = Project(title=u'My project')
    milestone1 = Task(title=u'Milestone #1', description=u'First milestone')
    milestone2 = Task(title=u'Milestone #2', description=u'Second milestone')
    milestone1.tasks = [Task(title=u'Subtask for Milestone #1'), ]
    milestone2.tasks = [Task(title=u'Subtask #1 for Milestone #2'),
                        Task(title=u'Subtask #2 for Milestone #2')]
    project.milestones = [milestone1, milestone2]
    project = get_content(project.id)
    assert project.title == u'My project'
    assert len(project.milestones) == 2
    assert [(x.position, x.title) for x in project.milestones] == \
        [(0, u'Milestone #1'), (1, u'Milestone #2')]
    # 3 tasks from the first section + 2 milestones + 3 milestone subtasks
    assert len(Task.query.all()) == 8
    # ``container`` must behave as an attribute returning the owning object
    container = milestone1.container
    assert isinstance(container, Project) == True
    # Test Worklist implementation
    worklist = Worklist(title=u'My worklist', employee=u'Torkel Lyng')
    worklist.tasks = [
        Task(title=u'Ask stackoverflow for help'),
        Task(title=u'Learn SQLAlchemy')
    ]
    worklist = get_content(worklist.id)
    assert worklist.title == u'My worklist'
    assert worklist.employee == u'Torkel Lyng'
    assert len(worklist.tasks) == 2
    # 8 tasks from above + 2 worklist tasks
    assert len(Task.query.all()) == 10
    assert isinstance(worklist.tasks[0].container, Worklist) == True
    # Cleanup
    # delete() is not defined in this view (presumably on the base class).
    # The 'all, delete, delete-orphan' cascades are expected to remove every
    # nested task as well.
    task = Task.query.filter_by(title=u'Buy milk').one()
    task.delete()
    project.delete()
    worklist.delete()
    assert len(Task.query.all()) == 0
# Run the smoke test only when executed as a script, not on import.
if __name__ == '__main__':
    main()
The container relationship on the Content class did not work as expected: it returned None unless I explicitly set task.container = somecontainer. I therefore opted for a property that returns either None or the container object. I'll investigate the subject further to perhaps find a better solution. Suggestions or alternative solutions are still very welcome.
To provide an activity log in my SQLAlchemy-based app, I have a model like this:
class ActivityLog(Base):
    """One audit-trail entry (the question's original, non-polymorphic model)."""
    __tablename__ = 'activitylog'
    id = Column(Integer, primary_key=True)
    # The user who performed the activity
    # (``relation`` is the legacy spelling of ``relationship``).
    activity_by_id = Column(Integer, ForeignKey('users.id'), nullable=False)
    activity_by = relation(User, primaryjoin=activity_by_id == User.id)
    activity_at = Column(DateTime, default=datetime.utcnow, nullable=False)
    # An ACTIVITY_* code such as ACTIVITY_ADD.
    activity_type = Column(SmallInteger, nullable=False)
    # Generic pointer to the affected row: table name + primary key, with no
    # database-level foreign key, so one log can reference any table.
    target_table = Column(Unicode(20), nullable=False)
    target_id = Column(Integer, nullable=False)
    # Copy of the target's title at the time the entry is written.
    target_title = Column(Unicode(255), nullable=False)
The log contains entries for multiple tables, so I can't use ForeignKey relations. Log entries are made like this:
# Create the document and make sure it has a primary key before logging it.
doc = Document(name=u'mydoc', title=u'My Test Document',
               created_by=user, edited_by=user)
session.add(doc)
session.flush()  # See note below: flushing assigns doc.id, needed for target_id
# Record the ADD activity, pointing at the document by table name + id.
log = ActivityLog(activity_by=user, activity_type=ACTIVITY_ADD,
                  target_table=Document.__table__.name, target_id=doc.id,
                  target_title=doc.title)
session.add(log)
This leaves me with three problems:
I have to flush the session before my doc object gets an id. If I had used a ForeignKey column and a relation mapper, I could have simply called ActivityLog(target=doc) and let SQLAlchemy do the work. Is there any way to work around needing to flush by hand?
The target_table parameter is too verbose. I suppose I could solve this with a target property setter in ActivityLog that automatically retrieves the table name and id from a given instance.
Biggest of all, I'm not sure how to retrieve a model instance from the database. Given an ActivityLog instance log, calling self.session.query(log.target_table).get(log.target_id) does not work, as query() expects a model as parameter.
One workaround appears to be to use polymorphism and derive all my models from a base model which ActivityLog recognises. Something like this:
class Entity(Base):
    """Polymorphic base for every auditable model (joined-table inheritance).

    ``entity_type`` holds the subclass's ``polymorphic_identity``, so loading
    an Entity yields the concrete subclass.
    """
    __tablename__ = 'entities'
    id = Column(Integer, primary_key=True)
    title = Column(Unicode(255), nullable=False)
    # onupdate only fires on UPDATE. Without a default, the INSERT would leave
    # this NOT NULL column empty and fail, so set it on creation too.
    edited_at = Column(DateTime, default=datetime.utcnow,
                       onupdate=datetime.utcnow, nullable=False)
    entity_type = Column(Unicode(20), nullable=False)
    __mapper_args__ = {'polymorphic_on': entity_type}
class Document(Entity):
    """A document stored as an Entity row plus a ``documents`` row."""
    __tablename__ = 'documents'
    __mapper_args__ = {'polymorphic_identity': 'document'}
    # Joined-table inheritance needs the subclass table's primary key to also
    # be a foreign key to the parent table. Without it SQLAlchemy cannot work
    # out how to join ``documents`` to ``entities``.
    id = Column(Integer, ForeignKey('entities.id'), primary_key=True)
    body = Column(UnicodeText, nullable=False)
class ActivityLog(Base):
    """Audit entry for the polymorphic design, with a real foreign key to Entity."""
    __tablename__ = 'activitylog'
    id = Column(Integer, primary_key=True)
    # (remaining columns elided; same as the first ActivityLog version)
    ...
    # Every audited model derives from Entity, so a genuine foreign key works.
    # Loading ``target`` returns the concrete subclass (e.g. Document) via
    # Entity's polymorphic_on.
    target_id = Column(Integer, ForeignKey('entities.id'), nullable=False)
    target = relation(Entity)
If I do this, ActivityLog(...).target will give me a Document instance when it refers to a Document, but I'm not sure it's worth the overhead of having two tables for everything. Should I go ahead and do it this way?
One way to solve this is polymorphic associations. It should solve all 3 of your issues and also make database foreign key constraints work. See the polymorphic association example in SQLAlchemy source. Mike Bayer has an old blogpost that discusses this in greater detail.
Definitely go through the blog post and examples Ants linked to. I did not find the explanation confusing, though it does assume some prior experience with the topic.
Few things I can suggest are:
ForeignKeys: in general I agree they are a good thing to have, but I am not sure they are conceptually important in your case. You seem to be using ActivityLog as an orthogonal, cross-cutting concern (AOP), whereas a version with foreign keys would effectively make your business objects aware of the ActivityLog. Another problem with using an FK for audit purposes in your schema is object deletion: if you allow it, the FK constraint will either block the deletion or, with cascading deletes, remove all the ActivityLog entries for that object.
Automatic logging: you are doing all this logging manually whenever you create/modify(/delete) the object. With SA you could implement a SessionExtension with before_commit which would do the job for you automatically.
This way you can avoid writing code like the following entirely:
# Manual logging that the automatic session hook makes unnecessary:
# build the ActivityLog by hand and add it to the session.
log = ActivityLog(activity_by=user, activity_type=ACTIVITY_ADD,
                  target_table=Document.__table__.name, target_id=doc.id,
                  target_title=doc.title)
session.add(log)
EDIT-1: complete sample code added
The code is based on the first non-FK version from http://techspot.zzzeek.org/?p=13.
The choice not to use FK is based on the fact that for audit purposes when the
main object is deleted, it should not cascade to delete all the audit log entries.
Also this keeps auditable objects unaware of the fact they are being audited.
The implementation uses an SA one-to-many relationship. It is possible that some
objects are modified many times, which will result in many audit log entries.
By default SA will load the relationship objects when adding a new entry to the
list. Assuming that during "normal" usage we would like only to add new audit
log entry, we use the lazy='noload' flag so that the relation from the main object
will never be loaded. It is loaded when navigated from the other side though,
and it can also be loaded from the main object using a custom query, which is shown
in the example as well via the activitylogs_readonly read-only property.
Code (runnable with some tests):
from datetime import datetime
from sqlalchemy import create_engine, Column, Integer, SmallInteger, String, DateTime, ForeignKey, Table, UnicodeText, Unicode, and_
from sqlalchemy.orm import relationship, dynamic_loader, scoped_session, sessionmaker, class_mapper, backref
from sqlalchemy.orm.session import Session
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm.interfaces import SessionExtension
import logging
# Log INFO and above through the root logger for the demo run.
logging.basicConfig(level=logging.INFO)
_logger = logging.getLogger()
# Activity type codes stored in ActivityLog.activity_type.
ACTIVITY_ADD, ACTIVITY_MOD, ACTIVITY_DEL = 1, 2, 3
class ActivityLogSessionExtension(SessionExtension):
    """Session extension that writes ActivityLog entries automatically.

    Just before each commit it walks the session's new, dirty and deleted
    objects, in that order. Every object that has a ``create_activitylog``
    method (attached by ``auditable()``) gets an ADD, MOD or DEL entry
    respectively.
    """
    _logger = logging.getLogger('ActivityLogSessionExtension')

    # (log label, Session attribute holding the objects, activity type).
    # The attribute is read lazily inside before_commit so each collection is
    # evaluated only after the previous phase has run.
    _PHASES = (
        ('add', 'new', ACTIVITY_ADD),
        ('mod', 'dirty', ACTIVITY_MOD),
        ('del', 'deleted', ACTIVITY_DEL),
    )

    def before_commit(self, session):
        self._logger.debug("before_commit: %s", session)
        for label, attr, activity_type in self._PHASES:
            for obj in getattr(session, attr):
                self._logger.info("before_commit >> %s: %s", label, obj)
                if not hasattr(obj, 'create_activitylog'):
                    continue
                # Session.dirty is "optimistic": assigning an attribute its
                # current value still puts the object there. Only record a
                # MOD entry when there is a net change.
                if attr == 'dirty' and not session.is_modified(obj):
                    continue
                obj.create_activitylog(activity_type)
# Configure test data SA: an in-memory SQLite database and a scoped session
# with autoflush disabled and the activity-log extension installed.
engine = create_engine('sqlite:///:memory:', echo=False)
session = scoped_session(
    sessionmaker(
        bind=engine,
        autoflush=False,
        extension=ActivityLogSessionExtension(),
    )
)
# Declarative base; every model gets a ``Model.query`` shortcut bound to it.
Base = declarative_base()
Base.query = session.query_property()
class _BaseMixin(object):
    """Helper mixin with a keyword-argument constructor and a generic __repr__.

    The constructor setattr()s every keyword argument onto the instance.

    __repr__ lists the instance ``__dict__`` entries in sorted key order. It
    skips keys starting with ``_sa_`` (SQLAlchemy's instance state) or
    ``_backref_`` (the backrefs created by auditable()). Because it reads
    ``__dict__`` directly instead of using attribute access, it shows only
    values already present on the instance. Related objects stored there are
    rendered with their own repr.
    """
    def __init__(self, **kwargs):
        for attr_name, attr_value in kwargs.items():
            setattr(self, attr_name, attr_value)

    def __repr__(self):
        state = self.__dict__
        shown = ['%s=%r' % (key, state[key])
                 for key in sorted(state)
                 if not key.startswith(('_sa_', '_backref_'))]
        return "<%s(%s)>" % (self.__class__.__name__, ', '.join(shown))
class User(Base, _BaseMixin):
    """A user who can be recorded as ActivityLog.activity_by."""
    __tablename__ = u'users'
    id = Column(Integer, primary_key=True)
    name = Column(String)
class Document(Base, _BaseMixin):
    """A document; registered with auditable() later in this module."""
    __tablename__ = u'documents'
    id = Column(Integer, primary_key=True)
    # Copied into ActivityLog.target_title by create_activitylog().
    title = Column(Unicode(255), nullable=False)
    body = Column(UnicodeText, nullable=False)
class Folder(Base, _BaseMixin):
    """A folder; registered with auditable() later in this module."""
    __tablename__ = u'folders'
    id = Column(Integer, primary_key=True)
    # Copied into ActivityLog.target_title by create_activitylog().
    title = Column(Unicode(255), nullable=False)
    comment = Column(UnicodeText, nullable=False)
class ActivityLog(Base, _BaseMixin):
    """One audit-trail entry pointing at any auditable row by table name + id.

    There is deliberately no foreign key on target_id, so entries survive
    deletion of the object they describe.
    """
    __tablename__ = u'activitylog'
    id = Column(Integer, primary_key=True)
    activity_by_id = Column(Integer, ForeignKey('users.id'), nullable=False)
    activity_by = relationship(User) # #note: no need to specify the primaryjoin
    activity_at = Column(DateTime, default=datetime.utcnow, nullable=False)
    # One of ACTIVITY_ADD / ACTIVITY_MOD / ACTIVITY_DEL.
    activity_type = Column(SmallInteger, nullable=False)
    target_table = Column(Unicode(20), nullable=False)
    target_id = Column(Integer, nullable=False)
    target_title = Column(Unicode(255), nullable=False)
    # backref relation for auditable
    # auditable() adds one ``_backref_<table name>`` relationship per audited
    # table. ``target`` picks the one matching this entry's target_table.
    # It raises AttributeError if target_table names a table that was never
    # passed through auditable().
    target = property(lambda self: getattr(self, '_backref_%s' % self.target_table))
def _get_user():
    """Return the User recorded as the actor of new ActivityLog entries.

    TODO: proper implementation required.
    HACK: always returns the user named 'user2'. ``.one()`` raises unless
    exactly one such user exists.
    """
    candidates = session.query(User).filter_by(name='user2')
    return candidates.one()
# auditable support function
# based on first non-FK version from http://techspot.zzzeek.org/?p=13
def auditable(cls, name):
    """Make the mapped class ``cls`` auditable through ActivityLog.

    Adds to ``cls``:
      * ``create_activitylog(activity_type)``: builds an ActivityLog for the
        instance, appends it to the ``name`` relationship and returns it.
      * ``<name>_readonly``: a read-only property that queries all
        ActivityLog rows belonging to the instance.
      * ``name``: a one-to-many relationship to ActivityLog (lazy='noload').
        It joins on primary key == target_id AND target_table == table name,
        with a ``_backref_<table name>`` backref on ActivityLog.

    NOTE(review): only the first primary-key column is used, so this assumes
    a single-column primary key.
    """
    def create_activitylog(self, activity_type):
        # ``table`` is bound later in the enclosing function; the closure
        # reads it at call time. target_id is not set here. Presumably the
        # relationship populates it when the log is flushed; confirm.
        log = ActivityLog(activity_by=_get_user(),
                          activity_type=activity_type,
                          target_table=table.name,
                          target_title=self.title,
                          )
        getattr(self, name).append(log)
        return log
    mapper = class_mapper(cls)
    table = mapper.local_table
    cls.create_activitylog = create_activitylog
    def _get_activitylog(self):
        # Explicit query, because the ``name`` relationship itself never loads.
        return Session.object_session(self).query(ActivityLog).with_parent(self).all()
    setattr(cls, '%s_readonly' %(name,), property(_get_activitylog))
    # no constraints, therefore define constraints in an ad-hoc fashion.
    primaryjoin = and_(
        list(table.primary_key)[0] == ActivityLog.__table__.c.target_id,
        ActivityLog.__table__.c.target_table == table.name
    )
    foreign_keys = [ActivityLog.__table__.c.target_id]
    mapper.add_property(name,
        # #note: because we use the relationship, by default all previous
        # ActivityLog items will be loaded for an object when new one is
        # added. To avoid this, use either dynamic_loader (http://www.sqlalchemy.org/docs/reference/orm/mapping.html#sqlalchemy.orm.dynamic_loader)
        # or lazy='noload'. This is the trade-off decision to be made.
        # Additional benefit of using lazy='noload' is that one can also
        # record DEL operations in the same way as ADD, MOD
        relationship(
            ActivityLog,
            lazy='noload', # important for relationship
            primaryjoin=primaryjoin,
            foreign_keys=foreign_keys,
            # The backref's join omits the target_table condition.
            backref=backref('_backref_%s' % table.name,
                primaryjoin=list(table.primary_key)[0] == ActivityLog.__table__.c.target_id,
                foreign_keys=foreign_keys)
        )
    )
# this will define which classes support the ActivityLog interface
# (each gets create_activitylog(), an ``activitylogs`` relationship and an
# ``activitylogs_readonly`` property)
auditable(Document, 'activitylogs')
auditable(Folder, 'activitylogs')
# create db schema
Base.metadata.create_all(engine)
## >>>>> TESTS >>>>>>
# create some basic data first
# (_get_user() looks up 'user2', so that user must exist before any
# auditable object is committed)
u1 = User(name='user1')
u2 = User(name='user2')
session.add(u1)
session.add(u2)
session.commit()
session.expunge_all()
# --check--
assert not(_get_user() is None)
##############################
## ADD
##############################
_logger.info('-' * 80)
d1 = Document(title=u'Document-1', body=u'Doc1 some body skipped the body')
# when not using SessionExtension for any reason, this can be called manually
#d1.create_activitylog(ACTIVITY_ADD)
session.add(d1)
# commit runs ActivityLogSessionExtension.before_commit, which logs the ADD
session.commit()
f1 = Folder(title=u'Folder-1', comment=u'This folder is empty')
# when not using SessionExtension for any reason, this can be called manually
#f1.create_activitylog(ACTIVITY_ADD)
session.add(f1)
session.commit()
# --check--
# One ADD entry per commit. NOTE(review): the index-based checks assume
# query(...).all() returns rows in insertion order, which is not guaranteed
# without order_by.
session.expunge_all()
logs = session.query(ActivityLog).all()
_logger.debug(logs)
assert len(logs) == 2
assert logs[0].activity_type == ACTIVITY_ADD
assert logs[0].target.title == u'Document-1'
assert logs[0].target.title == logs[0].target_title
assert logs[1].activity_type == ACTIVITY_ADD
assert logs[1].target.title == u'Folder-1'
assert logs[1].target.title == logs[1].target_title
##############################
## MOD(ify)
##############################
_logger.info('-' * 80)
session.expunge_all()
d1 = session.query(Document).filter_by(id=1).one()
assert d1.title == u'Document-1'
assert d1.body == u'Doc1 some body skipped the body'
# ``activitylogs`` is lazy='noload', so it never loads the existing ADD
# entry and reads as empty.
assert d1.activitylogs == []
d1.title = u'Modified: Document-1'
d1.body = u'Modified: body'
# when not using SessionExtension (or it does not work, this can be called manually)
#d1.create_activitylog(ACTIVITY_MOD)
session.commit()
# activitylogs_readonly issues an explicit query, so it does see the entries
_logger.debug(d1.activitylogs_readonly)
# --check--
session.expunge_all()
logs = session.query(ActivityLog).all()
assert len(logs)==3
assert logs[2].activity_type == ACTIVITY_MOD
assert logs[2].target.title == u'Modified: Document-1'
assert logs[2].target.title == logs[2].target_title
##############################
## DEL(ete)
##############################
_logger.info('-' * 80)
session.expunge_all()
d1 = session.query(Document).filter_by(id=1).one()
# when not using SessionExtension for any reason, this can be called manually,
#d1.create_activitylog(ACTIVITY_DEL)
session.delete(d1)
session.commit()
session.expunge_all()
# --check--
# (this second expunge_all() is redundant with the one just above)
session.expunge_all()
logs = session.query(ActivityLog).all()
# All entries for the deleted document survive (there is no FK cascade),
# but their ``target`` now resolves to None.
assert len(logs)==4
assert logs[0].target is None
assert logs[2].target is None
assert logs[3].activity_type == ACTIVITY_DEL
assert logs[3].target is None
##############################
## print all activity logs
##############################
_logger.info('=' * 80)
logs = session.query(ActivityLog).all()
for log in logs:
    # Touch ``target`` once before logging; the value itself is discarded.
    _ = log.target
    _logger.info("%s -> %s", log, log.target)
##############################
## navigate from main object
##############################
_logger.info('=' * 80)
session.expunge_all()
f1 = session.query(Folder).filter_by(id=1).one()
# Read-only property added by auditable(): queries this folder's entries.
_logger.info(f1.activitylogs_readonly)