Skip to content

Commit 9fdecfe

Browse files
authored
Merge pull request #514 from dimitri-yatsenko/dev
Fix #511: prohibit direct inserts outside the make callback
2 parents 0c9fac0 + e6295d0 commit 9fdecfe

File tree

6 files changed

+38
-9
lines changed

6 files changed

+38
-9
lines changed

datajoint/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ class key:
4646
config.add_history('No config file found, using default settings.')
4747
else:
4848
config.load(config_file)
49+
del config_file
50+
4951
del config_files
5052

5153
# override login credentials with environment variables
@@ -58,6 +60,7 @@ class key:
5860
for k in mapping:
5961
config.add_history('Updated login credentials from %s' % k)
6062
config.update(mapping)
63+
del mapping
6164

6265
logger.setLevel(log_levels[config['loglevel']])
6366

datajoint/autopopulate.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ class AutoPopulate:
2222
must define the property `key_source`, and must define the callback method `make`.
2323
"""
2424
_key_source = None
25+
_allow_insert = False
2526

2627
@property
2728
def key_source(self):
@@ -148,6 +149,7 @@ def handler(signum, frame):
148149
else:
149150
logger.info('Populating: ' + str(key))
150151
call_count += 1
152+
self._allow_insert = True
151153
try:
152154
make(dict(key))
153155
except (KeyboardInterrupt, SystemExit, Exception) as error:
@@ -172,6 +174,8 @@ def handler(signum, frame):
172174
self.connection.commit_transaction()
173175
if reserve_jobs:
174176
jobs.complete(self.target.table_name, self._job_key(key))
177+
finally:
178+
self._allow_insert = False
175179

176180
# place back the original signal handler
177181
if reserve_jobs:

datajoint/table.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,8 @@ def insert1(self, row, **kwargs):
152152
"""
153153
self.insert((row,), **kwargs)
154154

155-
def insert(self, rows, replace=False, skip_duplicates=False, ignore_extra_fields=False, ignore_errors=False):
155+
def insert(self, rows, replace=False, skip_duplicates=False, ignore_extra_fields=False, ignore_errors=False,
156+
allow_direct_insert=None):
156157
"""
157158
Insert a collection of rows.
158159
@@ -161,6 +162,7 @@ def insert(self, rows, replace=False, skip_duplicates=False, ignore_extra_fields
161162
:param replace: If True, replaces the existing tuple.
162163
:param skip_duplicates: If True, silently skip duplicate inserts.
163164
:param ignore_extra_fields: If False, fields that are not in the heading raise error.
165+
:param allow_direct_insert: applies only in auto-populated tables. Set True to insert outside populate calls.
164166
165167
Example::
166168
>>> relation.insert([
@@ -172,6 +174,11 @@ def insert(self, rows, replace=False, skip_duplicates=False, ignore_extra_fields
172174
warnings.warn('Use of `ignore_errors` in `insert` and `insert1` is deprecated. Use try...except... '
173175
'to explicitly handle any errors', stacklevel=2)
174176

177+
# prohibit direct inserts into auto-populated tables
178+
if not (allow_direct_insert or getattr(self, '_allow_insert', True)): # _allow_insert is only present in AutoPopulate
179+
raise DataJointError(
180+
'Auto-populate tables can only be inserted into from their make methods during populate calls.')
181+
175182
heading = self.heading
176183
if inspect.isclass(rows) and issubclass(rows, Query): # instantiate if a class
177184
rows = rows()

tests/schema.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,7 @@ class DecimalPrimaryKey(dj.Lookup):
275275
@schema
276276
class IndexRich(dj.Manual):
277277
definition = """
278-
-> Experiment
278+
-> Subject
279279
---
280280
-> [unique, nullable] User.proj(first="username")
281281
first_date : date

tests/test_autopopulate.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,26 @@
1-
from nose.tools import assert_raises, assert_equal, \
2-
assert_false, assert_true, assert_list_equal, \
3-
assert_tuple_equal, assert_dict_equal, raises
4-
1+
from nose.tools import assert_equal, assert_false, assert_true, raises
52
from . import schema
3+
from datajoint import DataJointError
64

75

86
class TestPopulate:
97
"""
108
Test base relations: insert, delete
119
"""
1210

13-
def __init__(self):
11+
def setUp(self):
1412
self.user = schema.User()
1513
self.subject = schema.Subject()
1614
self.experiment = schema.Experiment()
1715
self.trial = schema.Trial()
1816
self.ephys = schema.Ephys()
1917
self.channel = schema.Ephys.Channel()
2018

19+
def tearDown(self):
2120
# delete automatic tables just in case
2221
self.channel.delete_quick()
2322
self.ephys.delete_quick()
23+
self.trial.Condition.delete_quick()
2424
self.trial.delete_quick()
2525
self.experiment.delete_quick()
2626

@@ -49,3 +49,18 @@ def test_populate(self):
4949
self.ephys.populate()
5050
assert_true(self.ephys)
5151
assert_true(self.channel)
52+
53+
def test_allow_direct_insert(self):
54+
assert_true(self.subject, 'root tables are empty')
55+
key = self.subject.fetch('KEY')[0]
56+
key['experiment_id'] = 1000
57+
key['experiment_date'] = '2018-10-30'
58+
self.experiment.insert1(key, allow_direct_insert=True)
59+
60+
@raises(DataJointError)
61+
def test_allow_insert(self):
62+
assert_true(self.subject, 'root tables are empty')
63+
key = self.subject.fetch('KEY')[0]
64+
key['experiment_id'] = 1001
65+
key['experiment_date'] = '2018-10-30'
66+
self.experiment.insert1(key)

tests/test_declare.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,8 +105,8 @@ def test_dependencies():
105105
assert_true(experiment.full_table_name in set(user.children(primary=False)))
106106
assert_equal(set(experiment.parents(primary=False)), {user.full_table_name})
107107

108-
assert_equal(set(subject.children(primary=True)), {experiment.full_table_name})
109-
assert_equal(set(experiment.parents(primary=True)), {subject.full_table_name})
108+
assert_true(experiment.full_table_name in subject.descendants())
109+
assert_true(subject.full_table_name in experiment.ancestors())
110110

111111
assert_true(trial.full_table_name in experiment.descendants())
112112
assert_true(experiment.full_table_name in trial.ancestors())

0 commit comments

Comments
 (0)