Skip to content

Commit

Permalink
db: upgrade flask-sqlalchemy
Browse files Browse the repository at this point in the history
* replaces the previous public method `create_scoped_session` from
  flask-sqlalchemy with the direct way of creating a new session with
  SQLAlchemy. Related GitHub issue:
  jeancochrane/pytest-flask-sqlalchemy#63
  • Loading branch information
ntarocco committed Apr 12, 2023
1 parent d7397f3 commit 5b1eb35
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 19 deletions.
20 changes: 6 additions & 14 deletions pytest_invenio/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,25 +485,17 @@ def db(database):
the test (this is much faster than recreating the entire database).
"""
import sqlalchemy as sa
from random import randrange

connection = database.engine.connect()
transaction = connection.begin()
old_session = database.session

# Create a new session and assign it to the current db session.
# The new session is scoped and bound to the `connection`.
options = dict(bind=connection, binds={})
session = database.create_scoped_session(options=options)

session.begin_nested()

# `session` is actually a scoped_session. For the `after_transaction_end`
# event, we need a session instance to listen for, hence the `session()`
# call.
@sa.event.listens_for(session(), "after_transaction_end")
def restart_savepoint(sess, trans):
if trans.nested and not trans._parent.nested:
session.expire_all()
session.begin_nested()

old_session = database.session
session_factory = sa.orm.sessionmaker(**options)
session = sa.orm.scoped_session(session_factory)
database.session = session

yield database
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ install_requires =
tests =
pytest-black>=0.3.0
invenio-celery>=1.2.4,<2.0.0
invenio-db>=1.0.12,<2.0.0
invenio-db>=1.1.0,<2.0.0
invenio-files-rest>=1.3.2,<2.0.0
invenio-mail>=1.0.2,<2.0.0
invenio-search>=2.1.0,<3.0.0
Expand Down
8 changes: 4 additions & 4 deletions tests/test_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,11 +319,11 @@ def test_db(conftest_testdir):
"""Test database creation and initialization."""
conftest_testdir.makepyfile(
"""
from invenio_db import db
from invenio_db import db as _db
class UserA(db.Model):
id = db.Column(db.Integer, primary_key=True)
username = db.Column(db.String(80), unique=True)
class UserA(_db.Model):
id = _db.Column(_db.Integer, primary_key=True)
username = _db.Column(_db.String(80), unique=True)
def test_db1(db):
assert UserA.query.count() == 0
Expand Down

0 comments on commit 5b1eb35

Please sign in to comment.