Skip to content

Commit

Permalink
ORM: Use skip_orm as the default implementation for `SqlaGroup.add_…
Browse files Browse the repository at this point in the history
…nodes` and `SqlaGroup.remove_nodes` (#6720)

This commit removes the `skip_orm` flag and adopts the optimized approach behind it as the default behavior for the `add_nodes` and `remove_nodes` methods.

As discussed in #5453, the ORM-based implementations of `add_nodes` and `remove_nodes` were found to be inefficient, and to address this, the `skip_orm` flag was introduced. This flag enabled a faster, core API-based approach for these operations, but was never well documented or advertised.

All existing tests pass with the new implementation, and are sufficiently comprehensive to ensure its safety. Consequently, it is a reasonable decision to make the `skip_orm` implementation the default for bulk insertion and removal of elements within a group.
  • Loading branch information
rabbull authored Feb 10, 2025
1 parent f56fcc3 commit d2fbf21
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 122 deletions.
74 changes: 16 additions & 58 deletions src/aiida/storage/psql_dos/orm/groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,17 +167,10 @@ def add_nodes(self, nodes, **kwargs):
:note: all the nodes *and* the group itself have to be stored.
:param nodes: a list of `BackendNode` instance to be added to this group
:param kwargs:
skip_orm: When the flag is on, the SQLA ORM is skipped and SQLA is used
to create a direct SQL INSERT statement to the group-node relationship
table (to improve speed).
"""
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.exc import IntegrityError

super().add_nodes(nodes)
skip_orm = kwargs.get('skip_orm', False)

def check_node(given_node):
"""Check if given node is of correct type and stored"""
Expand All @@ -188,31 +181,16 @@ def check_node(given_node):
raise ValueError('At least one of the provided nodes is unstored, stopping...')

with utils.disable_expire_on_commit(self.backend.get_session()) as session:
if not skip_orm:
# Get dbnodes here ONCE, otherwise each call to dbnodes will re-read the current value in the database
dbnodes = self.model.dbnodes

for node in nodes:
check_node(node)

# Use pattern as suggested here:
# http://docs.sqlalchemy.org/en/latest/orm/session_transaction.html#using-savepoint
try:
with session.begin_nested():
dbnodes.append(node.bare_model)
session.flush()
except IntegrityError:
# Duplicate entry, skip
pass
else:
ins_dict = []
for node in nodes:
check_node(node)
ins_dict.append({'dbnode_id': node.id, 'dbgroup_id': self.id})

table = self.GROUP_NODE_CLASS.__table__
ins = insert(table).values(ins_dict)
session.execute(ins.on_conflict_do_nothing(index_elements=['dbnode_id', 'dbgroup_id']))
ins_dict = []
for node in nodes:
check_node(node)
ins_dict.append({'dbnode_id': node.id, 'dbgroup_id': self.id})
if len(ins_dict) == 0:
return

table = self.GROUP_NODE_CLASS.__table__
ins = insert(table).values(ins_dict)
session.execute(ins.on_conflict_do_nothing(index_elements=['dbnode_id', 'dbgroup_id']))

# Commit everything as up till now we've just flushed
if not session.in_nested_transaction():
Expand All @@ -224,45 +202,25 @@ def remove_nodes(self, nodes, **kwargs):
:note: all the nodes *and* the group itself have to be stored.
:param nodes: a list of `BackendNode` instance to be added to this group
:param kwargs:
skip_orm: When the flag is set to `True`, the SQLA ORM is skipped and SQLA is used to create a direct SQL
DELETE statement to the group-node relationship table in order to improve speed.
"""
from sqlalchemy import and_

super().remove_nodes(nodes)

# Get dbnodes here ONCE, otherwise each call to dbnodes will re-read the current value in the database
dbnodes = self.model.dbnodes
skip_orm = kwargs.get('skip_orm', False)

def check_node(node):
if not isinstance(node, self.NODE_CLASS):
raise TypeError(f'invalid type {type(node)}, has to be {self.NODE_CLASS}')

if node.id is None:
raise ValueError('At least one of the provided nodes is unstored, stopping...')

list_nodes = []

with utils.disable_expire_on_commit(self.backend.get_session()) as session:
if not skip_orm:
for node in nodes:
check_node(node)

# Check first, if SqlA issues a DELETE statement for an unexisting key it will result in an error
if node.bare_model in dbnodes:
list_nodes.append(node.bare_model)

for node in list_nodes:
dbnodes.remove(node)
else:
table = self.GROUP_NODE_CLASS.__table__
for node in nodes:
check_node(node)
clause = and_(table.c.dbnode_id == node.id, table.c.dbgroup_id == self.id)
statement = table.delete().where(clause)
session.execute(statement)
table = self.GROUP_NODE_CLASS.__table__
for node in nodes:
check_node(node)
clause = and_(table.c.dbnode_id == node.id, table.c.dbgroup_id == self.id)
statement = table.delete().where(clause)
session.execute(statement)

if not session.in_nested_transaction():
session.commit()
Expand Down
63 changes: 0 additions & 63 deletions tests/orm/implementation/test_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,66 +25,3 @@ def test_creation_from_dbgroup(backend):

assert group.pk == gcopy.pk
assert group.uuid == gcopy.uuid


def test_add_nodes_skip_orm():
"""Test the `SqlaGroup.add_nodes` method with the `skip_orm=True` flag."""
group = orm.Group(label='test_adding_nodes').store().backend_entity

node_01 = orm.Data().store().backend_entity
node_02 = orm.Data().store().backend_entity
node_03 = orm.Data().store().backend_entity
node_04 = orm.Data().store().backend_entity
node_05 = orm.Data().store().backend_entity
nodes = [node_01, node_02, node_03, node_04, node_05]

group.add_nodes([node_01], skip_orm=True)
group.add_nodes([node_02, node_03], skip_orm=True)
group.add_nodes((node_04, node_05), skip_orm=True)

assert set(_.pk for _ in nodes) == set(_.pk for _ in group.nodes)

# Try to add a node that is already present: there should be no problem
group.add_nodes([node_01], skip_orm=True)
assert set(_.pk for _ in nodes) == set(_.pk for _ in group.nodes)


def test_add_nodes_skip_orm_batch():
"""Test the `SqlaGroup.add_nodes` method with the `skip_orm=True` flag and batches."""
nodes = [orm.Data().store().backend_entity for _ in range(100)]

# Add nodes to groups using different batch size. Check in the end the correct addition.
batch_sizes = (1, 3, 10, 1000)
for batch_size in batch_sizes:
group = orm.Group(label=f'test_batches_{batch_size!s}').store()
group.backend_entity.add_nodes(nodes, skip_orm=True, batch_size=batch_size)
assert set(_.pk for _ in nodes) == set(_.pk for _ in group.nodes)


def test_remove_nodes_bulk():
"""Test node removal with `skip_orm=True`."""
group = orm.Group(label='test_removing_nodes').store().backend_entity

node_01 = orm.Data().store().backend_entity
node_02 = orm.Data().store().backend_entity
node_03 = orm.Data().store().backend_entity
node_04 = orm.Data().store().backend_entity
nodes = [node_01, node_02, node_03]

group.add_nodes(nodes)
assert set(_.pk for _ in nodes) == set(_.pk for _ in group.nodes)

# Remove a node that is not in the group: nothing should happen
group.remove_nodes([node_04], skip_orm=True)
assert set(_.pk for _ in nodes) == set(_.pk for _ in group.nodes)

# Remove one Node
nodes.remove(node_03)
group.remove_nodes([node_03], skip_orm=True)
assert set(_.pk for _ in nodes) == set(_.pk for _ in group.nodes)

# Remove a list of Nodes and check
nodes.remove(node_01)
nodes.remove(node_02)
group.remove_nodes([node_01, node_02], skip_orm=True)
assert set(_.pk for _ in nodes) == set(_.pk for _ in group.nodes)
25 changes: 24 additions & 1 deletion tests/orm/test_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,13 +149,27 @@ def test_add_nodes(self):
group.add_nodes(node_01)
assert set(_.pk for _ in nodes) == set(_.pk for _ in group.nodes)

# Try to add nothing: there should be no problem
group.add_nodes([])
assert set(_.pk for _ in nodes) == set(_.pk for _ in group.nodes)

nodes = [orm.Data().store().backend_entity for _ in range(100)]

# Add nodes to groups using different batch size. Check in the end the correct addition.
batch_sizes = (1, 3, 10, 1000)
for batch_size in batch_sizes:
group = orm.Group(label=f'test_batches_{batch_size!s}').store()
group.backend_entity.add_nodes(nodes, batch_size=batch_size)
assert set(_.pk for _ in nodes) == set(_.pk for _ in group.nodes)

def test_remove_nodes(self):
"""Test node removal."""
node_01 = orm.Data().store()
node_02 = orm.Data().store()
node_03 = orm.Data().store()
node_04 = orm.Data().store()
nodes = [node_01, node_02, node_03]
node_05 = orm.Data().store()
nodes = [node_01, node_02, node_03, node_05]
group = orm.Group(label=uuid.uuid4().hex).store()

# Add initial nodes
Expand All @@ -177,6 +191,15 @@ def test_remove_nodes(self):
group.remove_nodes([node_01, node_02])
assert set(_.pk for _ in nodes) == set(_.pk for _ in group.nodes)

# Remove to empty
nodes.remove(node_05)
group.remove_nodes([node_05])
assert set(_.pk for _ in nodes) == set(_.pk for _ in group.nodes)

# Try to remove nothing: there should be no problem
group.remove_nodes([])
assert set(_.pk for _ in nodes) == set(_.pk for _ in group.nodes)

def test_clear(self):
"""Test the `clear` method to remove all nodes."""
node_01 = orm.Data().store()
Expand Down

0 comments on commit d2fbf21

Please sign in to comment.