11from PIL import Image , ImageDraw
22from imageio import imread
33import numpy as np
4+ import traceback
5+ import sys
6+ from threading import Thread
7+ from concurrent .futures import ThreadPoolExecutor
8+
49
510class DiffgramDatasetIterator :
611
7- def __init__ (self , project , diffgram_file_id_list , validate_ids = True ):
12+ def __init__ (self , project ,
13+ diffgram_file_id_list ,
14+ validate_ids = True ,
15+ max_size_cache = 1073741824 ,
16+ max_num_concurrent_fetches = 25 ):
817 """
918
1019 :param project (sdk.core.core.Project): A Project object from the Diffgram SDK
1120 :param diffgram_file_list (list): An arbitrary number of file ID's from Diffgram.
1221 """
1322 self .diffgram_file_id_list = diffgram_file_id_list
14-
23+ self .max_size_cache = max_size_cache
24+ self .pool = ThreadPoolExecutor (max_num_concurrent_fetches )
1525 self .project = project
26+ self .file_cache = {}
1627 self ._internal_file_list = []
1728 if validate_ids :
1829 self .__validate_file_ids ()
@@ -25,22 +36,58 @@ def __iter__(self):
2536 def __len__ (self ):
2637 return len (self .diffgram_file_id_list )
2738
28- def __getitem__ (self , idx ):
29- diffgram_file = self .project .file .get_by_id (self .diffgram_file_id_list [idx ], with_instances = True )
def save_file_in_cache(self, idx, instance_data):
    """
    Store the instance data of one file in the in-memory cache, evicting the
    oldest entries first when the cache would exceed ``self.max_size_cache``.

    The footprint of every entry is estimated once, when it is stored, by
    walking its containers (``sys.getsizeof`` alone is shallow and would only
    report the size of the dict's own table, never the cached payloads).

    :param idx (int): Position of the file in ``self.diffgram_file_id_list``; used as cache key.
    :param instance_data: The data to cache for that position.
    :return: None
    """

    def estimate_size(obj, seen):
        # Approximate deep size in bytes; `seen` prevents double counting
        # of shared objects and infinite recursion on cycles.
        if id(obj) in seen:
            return 0
        seen.add(id(obj))
        size = sys.getsizeof(obj, 0)
        # Buffer-backed objects that expose an integer `nbytes` (for example
        # numpy arrays) may not own their data, so trust the larger figure.
        nbytes = getattr(obj, 'nbytes', None)
        if isinstance(nbytes, int):
            return max(size, nbytes)
        if isinstance(obj, dict):
            for key, value in obj.items():
                size += estimate_size(key, seen) + estimate_size(value, seen)
        elif isinstance(obj, (list, tuple, set, frozenset)):
            for item in obj:
                size += estimate_size(item, seen)
        return size

    # Per-entry size bookkeeping, created lazily so that no change to
    # __init__ is required. dict.setdefault is a single atomic operation.
    entry_sizes = vars(self).setdefault('_cache_entry_sizes', {})

    new_size = estimate_size(instance_data, set())
    # An entry being overwritten must not count against the budget.
    entry_sizes.pop(idx, None)
    cached_bytes = sum(list(entry_sizes.values()))

    # Dicts preserve insertion order, so iterating the keys visits the
    # oldest entries first.
    for key in list(self.file_cache):
        if cached_bytes + new_size <= self.max_size_cache:
            break
        if key == idx:
            continue
        # This method is also called from the prefetch pool threads; another
        # thread may already have evicted `key`, hence the tolerant pops.
        self.file_cache.pop(key, None)
        cached_bytes -= entry_sizes.pop(key, 0)

    self.file_cache[idx] = instance_data
    entry_sizes[idx] = new_size
48+
def get_next_n_items(self, idx, num_items=25):
    """
    Proactively fetch the items that follow ``idx`` and save them to the cache.

    One task per item is submitted to ``self.pool``; this method does not wait
    for the tasks to finish. Positions that are already cached are skipped, and
    the range is clipped at the end of ``self.diffgram_file_id_list``.

    :param idx (int): Position of the item that was just requested.
    :param num_items (int): How many of the following items to schedule.
    :return: True, always.
    """
    # idx + 1 is the first item after the current one, so the exclusive upper
    # bound for N items is idx + 1 + N.
    stop = min(idx + 1 + num_items, len(self.diffgram_file_id_list))

    for position in range(idx + 1, stop):
        if position in self.file_cache:
            continue
        # Pass the index itself as the positional argument: submit(fn, *args)
        # forwards its arguments as-is, so wrapping it in a tuple would hand
        # the worker a tuple instead of an int.
        self.pool.submit(self.__get_file_data_for_index, position)
    return True
63+
def __get_file_data_for_index(self, idx):
    """
    Fetch one file from Diffgram, convert it to instance data and cache it.

    :param idx (int): Position of the wanted file ID in ``self.diffgram_file_id_list``.
    :return: Whatever ``self.get_file_instances`` produces for the fetched file.
    """
    fetched_file = self.project.file.get_by_id(
        self.diffgram_file_id_list[idx],
        with_instances=True,
        use_session=False,
    )
    payload = self.get_file_instances(fetched_file)
    # Remember the result so later lookups of this position skip the fetch.
    self.save_file_in_cache(idx, payload)
    return payload
3269
70+ def __getitem__ (self , idx ):
71+ if self .file_cache .get (idx ):
72+ return self .file_cache .get (idx )
73+
74+ result = self .__get_file_data_for_index (idx )
75+
76+ self .get_next_n_items (idx , num_items = 25 )
77+
78+ return result
79+
3380 def __next__ (self ):
34- file_id = self .diffgram_file_id_list [ self .current_file_index ]
35- diffgram_file = self .project . file . get_by_id ( file_id , with_instances = True )
36- instance_data = self .get_file_instances ( diffgram_file )
81+ if self .file_cache . get ( self .current_file_index ):
82+ return self .file_cache . get ( self . current_file_index )
83+ instance_data = self .__get_file_data_for_index ( self . current_file_index )
3784 self .current_file_index += 1
3885 return instance_data
3986
4087 def __validate_file_ids (self ):
4188 if not self .diffgram_file_id_list :
4289 return
43- result = self .project .file .file_list_exists (self .diffgram_file_id_list )
90+ result = self .project .file .file_list_exists (self .diffgram_file_id_list , use_session = False )
4491 if not result :
4592 raise Exception (
4693 'Some file IDs do not belong to the project. Please provide only files from the same project.' )
@@ -56,7 +103,9 @@ def get_image_data(self, diffgram_file):
56103 if i < MAX_RETRIES - 1 :
57104 continue
58105 else :
59- raise e
106+ print ('Fetch Image Failed: Diffgram File ID: {}' .format (diffgram_file .id ))
107+ print (traceback .format_exc ())
108+ return None
60109 return image
61110 else :
62111 raise Exception ('Pytorch datasets only support images. Please provide only file_ids from images' )
0 commit comments