Skip to content

Commit

Permalink
fixes related instances bug (#97)
Browse files Browse the repository at this point in the history
  • Loading branch information
bschreck authored and kmax12 committed Feb 20, 2018
1 parent 98eab54 commit e6390c1
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 12 deletions.
21 changes: 9 additions & 12 deletions featuretools/entityset/entityset.py
Original file line number Diff line number Diff line change
Expand Up @@ -983,18 +983,15 @@ def _related_instances(self, start_entity_id, final_entity_id,
training_window_is_dict = isinstance(training_window, dict)
window = training_window
start_estore = self.entity_stores[start_entity_id]
if instance_ids is None:
df = start_estore.df
else: # instance_ids was passed in
# This check might be brittle
if not hasattr(instance_ids, '__iter__'):
instance_ids = [instance_ids]

if training_window_is_dict:
window = training_window.get(start_estore.id)
df = start_estore.query_by_values(instance_ids,
time_last=time_last,
training_window=window)
# This check might be brittle
if instance_ids is not None and not hasattr(instance_ids, '__iter__'):
instance_ids = [instance_ids]

if training_window_is_dict:
window = training_window.get(start_estore.id)
df = start_estore.query_by_values(instance_vals=instance_ids,
time_last=time_last,
training_window=window)
# if we're querying on a path that's not actually a path, just return
# the relevant slice of the entityset
if start_entity_id == final_entity_id:
Expand Down
8 changes: 8 additions & 0 deletions featuretools/tests/entityset_tests/test_pandas_es.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,14 @@ def test_related_instances_all(self, entityset):
for p in entityset.get_column_data('products', 'id').values:
assert p in result['id'].values

def test_related_instances_all_cutoff_time_same_entity(self, entityset):
# test querying across the entityset
result = entityset._related_instances(
start_entity_id='log', final_entity_id='log',
instance_ids=None, time_last=pd.Timestamp('2011/04/09 10:30:31'))

assert result['id'].values.tolist() == list(range(5))

def test_related_instances_link_vars(self, entityset):
# test adding link variables on the fly during _related_instances
frame = entityset._related_instances(
Expand Down

0 comments on commit e6390c1

Please sign in to comment.