diff --git a/README.rst b/README.rst index b079b45..e84420a 100644 --- a/README.rst +++ b/README.rst @@ -40,7 +40,7 @@ Add mixin to model .. code-block:: python from sqlalchemy import Column, Integer, Boolean - from sqlalchemy.ext.declarative import declarative_base + from sqlalchemy.orm import declarative_base from sqlalchemy_mptt.mixins import BaseNestedSets diff --git a/docs/initialize.rst b/docs/initialize.rst index 1777e67..f09a7a0 100644 --- a/docs/initialize.rst +++ b/docs/initialize.rst @@ -7,7 +7,7 @@ Create model with MPTT mixin: :linenos: from sqlalchemy import Column, Integer, Boolean - from sqlalchemy.ext.declarative import declarative_base + from sqlalchemy.orm import declarative_base from sqlalchemy_mptt.mixins import BaseNestedSets diff --git a/requirements.txt b/requirements.txt index 12e368a..568a359 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1 @@ -SQLAlchemy>=1.0.0 +SQLAlchemy>=1.4 diff --git a/sqlalchemy_mptt/events.py b/sqlalchemy_mptt/events.py index d903d33..821b808 100644 --- a/sqlalchemy_mptt/events.py +++ b/sqlalchemy_mptt/events.py @@ -41,7 +41,7 @@ def _insert_subtree( delta_rgt = delta_lft + node_size - 1 connection.execute( - table.update( + table.update().where( table_pk.in_(subtree) ).values( lft=table.c.lft - node_pos_left + delta_lft, @@ -53,7 +53,7 @@ def _insert_subtree( # step 2: update key of right side connection.execute( - table.update( + table.update().where( and_( table.c.rgt > delta_lft - 1, table_pk.notin_(subtree), @@ -62,12 +62,10 @@ def _insert_subtree( ).values( rgt=table.c.rgt + node_size, lft=case( - [ - ( - table.c.lft > left_sibling['lft'], - table.c.lft + node_size - ) - ], + ( + table.c.lft > left_sibling['lft'], + table.c.lft + node_size + ), else_=table.c.lft ) ) @@ -94,9 +92,7 @@ def mptt_before_insert(mapper, connection, instance): instance.level = instance.get_default_level() tree_id = connection.scalar( select( - [ - func.max(table.c.tree_id) + 1 - ] + func.max(table.c.tree_id) + 1 ) ) or 1 instance.tree_id = tree_id @@ -106,12 +102,10 @@ def mptt_before_insert(mapper, connection, instance): parent_tree_id, parent_level) = connection.execute( select( - [ - table.c.lft, - table.c.rgt, - table.c.tree_id, - table.c.level - ] + table.c.lft, + table.c.rgt, + table.c.tree_id, + table.c.level ).where( table_pk == instance.parent_id ) @@ -119,26 +113,22 @@ def mptt_before_insert(mapper, connection, instance): # Update key of right side connection.execute( - table.update( + table.update().where( and_(table.c.rgt >= parent_pos_right, table.c.tree_id == parent_tree_id) ).values( lft=case( - [ - ( - table.c.lft > parent_pos_right, - table.c.lft + 2 - ) - ], + ( + table.c.lft > parent_pos_right, + table.c.lft + 2 + ), else_=table.c.lft ), rgt=case( - [ - ( - table.c.rgt >= parent_pos_right, - table.c.rgt + 2 - ) - ], + ( + table.c.rgt >= parent_pos_right, + table.c.rgt + 2 + ), else_=table.c.rgt ) ) @@ -158,10 +148,8 @@ def mptt_before_delete(mapper, connection, instance, delete=True): table_pk = getattr(table.c, db_pk.name) lft, rgt = connection.execute( select( - [ - table.c.lft, - table.c.rgt - ] + table.c.lft, + table.c.rgt ).where( table_pk == pk ) @@ -171,7 +159,7 @@ def mptt_before_delete(mapper, connection, instance, delete=True): if delete: mapper.base_mapper.confirm_deleted_rows = False connection.execute( - table.delete( + table.delete().where( table_pk == pk ) ) @@ -190,28 +178,24 @@ def mptt_before_delete(mapper, connection, instance, delete=True): END """ connection.execute( - table.update( + table.update().where( and_( table.c.rgt > rgt, table.c.tree_id == tree_id ) ).values( lft=case( - [ - ( - table.c.lft > lft, - table.c.lft - delta - ) - ], + ( + table.c.lft > lft, + table.c.lft - delta + ), else_=table.c.lft ), rgt=case( - [ - ( - table.c.rgt >= rgt, - table.c.rgt - delta - ) - ], + ( + table.c.rgt >= rgt, + table.c.rgt - delta + ), else_=table.c.rgt ) ) @@ -243,25 +227,21 @@ def mptt_before_update(mapper, connection, instance): right_sibling_tree_id ) = connection.execute( select( - [ - table.c.lft, - table.c.rgt, - table.c.parent_id, - table.c.level, - table.c.tree_id - ] + table.c.lft, + table.c.rgt, + table.c.parent_id, + table.c.level, + table.c.tree_id ).where( table_pk == instance.mptt_move_before ) ).fetchone() current_lvl_nodes = connection.execute( select( - [ - table.c.lft, - table.c.rgt, - table.c.parent_id, - table.c.tree_id - ] + table.c.lft, + table.c.rgt, + table.c.parent_id, + table.c.tree_id ).where( and_( table.c.level == right_sibling_level, @@ -296,12 +276,10 @@ def mptt_before_update(mapper, connection, instance): left_sibling_tree_id ) = connection.execute( select( - [ - table.c.lft, - table.c.rgt, - table.c.parent_id, - table.c.tree_id - ] + table.c.lft, + table.c.rgt, + table.c.parent_id, + table.c.tree_id ).where( table_pk == instance.mptt_move_after ) @@ -320,7 +298,7 @@ def mptt_before_update(mapper, connection, instance): ORDER BY left_key """ subtree = connection.execute( - select([table_pk]) + select(table_pk) .where( and_( table.c.lft >= instance.left, @@ -345,13 +323,11 @@ def mptt_before_update(mapper, connection, instance): node_level ) = connection.execute( select( - [ - table.c.lft, - table.c.rgt, - table.c.tree_id, - table.c.parent_id, - table.c.level - ] + table.c.lft, + table.c.rgt, + table.c.tree_id, + table.c.parent_id, + table.c.level ).where( table_pk == node_id ) @@ -375,13 +351,11 @@ def mptt_before_update(mapper, connection, instance): parent_level ) = connection.execute( select( - [ - table_pk, - table.c.rgt, - table.c.lft, - table.c.tree_id, - table.c.level - ] + table_pk, + table.c.rgt, + table.c.lft, + table.c.tree_id, + table.c.level ).where( table_pk == instance.parent_id ) @@ -405,13 +379,11 @@ def mptt_before_update(mapper, connection, instance): parent_level ) = connection.execute( select( - [ - table_pk, - table.c.rgt, - table.c.lft, - table.c.tree_id, - table.c.level - ] + table_pk, + table.c.rgt, + table.c.lft, + table.c.tree_id, + table.c.level ).where( table_pk == instance.parent_id ) @@ -449,7 +421,7 @@ def mptt_before_update(mapper, connection, instance): if left_sibling_tree_id or left_sibling_tree_id == 0: tree_id = left_sibling_tree_id + 1 connection.execute( - table.update( + table.update().where( table.c.tree_id > left_sibling_tree_id ).values( tree_id=table.c.tree_id + 1 @@ -459,14 +431,12 @@ def mptt_before_update(mapper, connection, instance): else: tree_id = connection.scalar( select( - [ - func.max(table.c.tree_id) + 1 - ] + func.max(table.c.tree_id) + 1 ) ) connection.execute( - table.update( + table.update().where( table_pk.in_( subtree ) diff --git a/sqlalchemy_mptt/mixins.py b/sqlalchemy_mptt/mixins.py index 5cdda2d..d1fb0d5 100644 --- a/sqlalchemy_mptt/mixins.py +++ b/sqlalchemy_mptt/mixins.py @@ -29,8 +29,7 @@ class BaseNestedSets(object): .. code:: from sqlalchemy import Boolean, Column, create_engine, Integer - from sqlalchemy.ext.declarative import declarative_base - from sqlalchemy.orm import sessionmaker + from sqlalchemy.orm import sessionmaker, declarative_base from sqlalchemy_mptt.mixins import BaseNestedSets @@ -65,7 +64,11 @@ def get_pk_name(cls): @classmethod def get_pk_column(cls): - return getattr(cls, cls.get_pk_name()) + col = getattr(cls, cls.get_pk_name()) + # might be a Mapped column + if hasattr(col, "column") and hasattr(col.column, "name"): + return col.column + return col def get_pk_value(self): return getattr(self, self.get_pk_name()) @@ -342,7 +345,7 @@ def drilldown_tree(self, session=None, json=False, json_fields=None): ) def path_to_root(self, session=None, order=desc): - """Generate path from a leaf or intermediate node to the root. + r"""Generate path from a leaf or intermediate node to the root. For example: @@ -372,7 +375,7 @@ def path_to_root(self, session=None, order=desc): return self._base_order(query, order=order) def get_siblings(self, include_self=False, session=None): - """ + r""" * https://github.com/uralbash/sqlalchemy_mptt/issues/64 * https://django-mptt.readthedocs.io/en/latest/models.html#get-siblings-include-self-false @@ -414,7 +417,7 @@ def get_siblings(self, include_self=False, session=None): return query def get_children(self, session=None): - """ + r""" * https://github.com/uralbash/sqlalchemy_mptt/issues/64 * https://github.com/django-mptt/django-mptt/blob/fd76a816e05feb5fb0fc23126d33e514460a0ead/mptt/models.py#L563 diff --git a/sqlalchemy_mptt/tests/test_events.py b/sqlalchemy_mptt/tests/test_events.py index e4bb734..a402604 100644 --- a/sqlalchemy_mptt/tests/test_events.py +++ b/sqlalchemy_mptt/tests/test_events.py @@ -14,8 +14,7 @@ from sqlalchemy import Column, Boolean, Integer, create_engine from sqlalchemy.event import contains -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import sessionmaker, declarative_base from sqlalchemy_mptt import mptt_sessionmaker diff --git a/sqlalchemy_mptt/tests/test_inheritance.py b/sqlalchemy_mptt/tests/test_inheritance.py index 855a3f2..947f89e 100644 --- a/sqlalchemy_mptt/tests/test_inheritance.py +++ b/sqlalchemy_mptt/tests/test_inheritance.py @@ -1,8 +1,7 @@ import unittest import sqlalchemy as sa -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import sessionmaker, declarative_base from . import TreeTestingMixin from ..mixins import BaseNestedSets @@ -59,7 +58,7 @@ def test_create_generic(self): self.session.add(GenericTree(ppk=1)) self.session.commit() - tree = self.session.query(GenericTree).get(1) + tree = self.session.get(GenericTree, 1) self.assertEqual(tree.ppk, 1) self.assertEqual(tree.tree_id, 1) @@ -67,7 +66,7 @@ def test_create_spec(self): self.session.add(SpecializedTree(ppk=1)) self.session.commit() - tree = self.session.query(SpecializedTree).get(1) + tree = self.session.get(SpecializedTree, 1) self.assertEqual(tree.ppk, 1) self.assertEqual(tree.tree_id, 1) @@ -83,21 +82,21 @@ def test_create_delete(self): self.session.add(parent) self.session.commit() - tree = self.session.query(SpecializedTree).get(1) + tree = self.session.get(SpecializedTree, 1) self.assertEqual(tree.ppk, 1) self.assertEqual(tree.tree_id, 1) self.session.delete(child1) self.session.commit() - self.assertEquals(None, self.session.query(SpecializedTree).get(2)) + self.assertEquals(None, self.session.get(SpecializedTree, 2)) self.session.delete(child2) self.session.commit() - self.assertEquals(None, self.session.query(SpecializedTree).get(3)) - self.assertEquals(None, self.session.query(SpecializedTree).get(4)) - self.assertEquals(None, self.session.query(SpecializedTree).get(5)) + self.assertEquals(None, self.session.get(SpecializedTree, 3)) + self.assertEquals(None, self.session.get(SpecializedTree, 4)) + self.assertEquals(None, self.session.get(SpecializedTree, 5)) class TestGenericTree(TreeTestingMixin, unittest.TestCase): diff --git a/sqlalchemy_mptt/tests/test_mixins.py b/sqlalchemy_mptt/tests/test_mixins.py index 0bc1bf9..0868a3b 100644 --- a/sqlalchemy_mptt/tests/test_mixins.py +++ b/sqlalchemy_mptt/tests/test_mixins.py @@ -12,7 +12,7 @@ import unittest from sqlalchemy import Column, Integer -from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import declarative_base from ..mixins import BaseNestedSets diff --git a/tox.ini b/tox.ini index b787b9a..84234d3 100644 --- a/tox.ini +++ b/tox.ini @@ -6,3 +6,6 @@ deps= -rrequirements.txt nose commands=nosetests +setenv = + SQLALCHEMY_WARN_20=1 + PYTHONWARNINGS=always::DeprecationWarning