@@ -35,33 +35,56 @@ class SuperpixelClassificationTensorflow(SuperpixelClassificationBase):
     def __init__(self):
         self.training_optimal_batchsize: Optional[int] = None
         self.prediction_optimal_batchsize: Optional[int] = None
+        self.use_cuda = False
 
     def trainModelDetails(self, record, annotationName, batchSize, epochs, itemsAndAnnot, prog,
-                          tempdir, trainingSplit):
-        # print(f'Tensorflow trainModelDetails(batchSize={batchSize}, ...)')
-        # make model
-        num_classes = len(record['labels'])
-        model = tf.keras.Sequential([
-            tf.keras.layers.Rescaling(1.0 / 255),
-            tf.keras.layers.Conv2D(16, 3, padding='same', activation='relu'),
-            tf.keras.layers.MaxPooling2D(),
-            tf.keras.layers.Conv2D(32, 3, padding='same', activation='relu'),
-            tf.keras.layers.MaxPooling2D(),
-            tf.keras.layers.Conv2D(64, 3, padding='same', activation='relu'),
-            tf.keras.layers.MaxPooling2D(),
-            tf.keras.layers.Flatten(),
-            # tf.keras.layers.Dropout(0.2),
-            tf.keras.layers.Dense(128, activation='relu'),
-            tf.keras.layers.Dense(num_classes)])
-        prog.progress(0.2)
-        model.compile(optimizer='adam',
-                      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
-                      metrics=['accuracy'])
+                          tempdir, trainingSplit, use_cuda):
+        self.use_cuda = use_cuda
+
+        # Enable GPU memory growth globally to avoid precondition errors
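+        # (set_memory_growth must be configured before TensorFlow initializes
+        # the GPUs; changing it after initialization raises a RuntimeError.)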
+        gpus = tf.config.list_physical_devices('GPU')
+        if gpus and self.use_cuda:
+            try:
+                for gpu in gpus:
+                    tf.config.experimental.set_memory_growth(gpu, True)
+            except RuntimeError as e:
+                print(f"Could not set memory growth: {e}")
+        if not self.use_cuda:
+            tf.config.set_visible_devices([], 'GPU')
+        device = "gpu" if use_cuda else "cpu"
+        print(f"Using device: {device}")
+
+        # Dataset preparation (outside strategy scope)
+        ds_h5 = record['ds']
+        labelds_h5 = record['labelds']
+        # Fully load to memory and break h5py reference
+        ds_numpy = np.array(ds_h5[:])
+        labelds_numpy = np.array(labelds_h5[:])
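+        # Copying into numpy makes the tf.data pipeline independent of the
+        # open HDF5 file, so h5py is never touched from TensorFlow's threads.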
+
+        strategy = tf.distribute.MirroredStrategy()
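+        # Variables created under strategy.scope() are mirrored across all
+        # visible GPUs; with one visible device this reduces to ordinary execution.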
+        with strategy.scope():
+            num_classes = len(record['labels'])
+            model = tf.keras.Sequential([
+                tf.keras.layers.Rescaling(1.0 / 255),
+                tf.keras.layers.Conv2D(16, 3, padding='same', activation='relu'),
+                tf.keras.layers.MaxPooling2D(),
+                tf.keras.layers.Conv2D(32, 3, padding='same', activation='relu'),
+                tf.keras.layers.MaxPooling2D(),
+                tf.keras.layers.Conv2D(64, 3, padding='same', activation='relu'),
+                tf.keras.layers.MaxPooling2D(),
+                tf.keras.layers.Flatten(),
+                tf.keras.layers.Dense(128, activation='relu'),
+                tf.keras.layers.Dense(num_classes)])
+            prog.progress(0.2)
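+            # SparseCategoricalCrossentropy(from_logits=True) matches the final
+            # Dense layer, which emits raw logits; tf.nn.softmax is applied later.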
+            model.compile(optimizer='adam',
+                          loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+                          metrics=['accuracy'])
+
         prog.progress(0.7)
-        # generate split
-        full_ds = tf.data.Dataset.from_tensor_slices((record['ds'], record['labelds']))
-        full_ds = full_ds.shuffle(1000)  # add seed=123 ?
-        count = len(full_ds)
+        # generate split using numpy arrays
+        full_ds = tf.data.Dataset.from_tensor_slices((ds_numpy, labelds_numpy))
+        full_ds = full_ds.shuffle(1000)
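+        # Counting the numpy array directly avoids asking the tf.data pipeline
+        # for its cardinality.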
+        count = len(ds_numpy)
         train_size = int(count * trainingSplit)
         if batchSize < 1:
             batchSize = self.findOptimalBatchSize(model, full_ds, training=True)
@@ -85,24 +108,53 @@ def trainModelDetails(self, record, annotationName, batchSize, epochs, itemsAndAnnot, prog,
         self.saveModel(model, modelPath)
         return history, modelPath
 
+    def _get_device(self, use_cuda):
+        if tf.config.list_physical_devices('GPU') and use_cuda:
+            return '/GPU:0'
+        return '/CPU:0'
+
     def predictLabelsForItemDetails(
-        self, batchSize, ds: h5py._hl.dataset.Dataset, item, model, prog,
+        self, batchSize, ds: h5py._hl.dataset.Dataset, indices, item, model, use_cuda, prog,
     ):
-        # print(f'Tensorflow predictLabelsForItemDetails(batchSize={batchSize}, ...)')
         if batchSize < 1:
             batchSize = self.findOptimalBatchSize(
                 model, tf.data.Dataset.from_tensor_slices(ds), training=False,
             )
             print(f'Optimal batch size for prediction = {batchSize}')
-        predictions = model.predict(
-            ds,
-            batch_size=batchSize,
-            callbacks=[_LogTensorflowProgress(
-                prog, (ds.shape[0] + batchSize - 1) // batchSize, 0.05, 0.35, item)])
-        prog.item_progress(item, 0.4)
-        # softmax to scale to 0 to 1
-        catWeights = tf.nn.softmax(predictions)
-        return catWeights, predictions
+
+        device = self._get_device(use_cuda)
+        with tf.device(device):
+            # Create a dataset that pairs the data with their indices
+            dataset = tf.data.Dataset.from_tensor_slices((ds, indices))
+            dataset = dataset.batch(batchSize)
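+            # The index column rides along through batching, so every prediction
+            # can be matched back to its source superpixel row afterwards.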
+
+            # Initialize arrays to store results
+            all_predictions = []
+            all_cat_weights = []
+            all_indices = []
+
+            # Iterate through batches manually to keep track of indices
+            for data, batch_indices in dataset:
+                batch_predictions = model.predict(
+                    data,
+                    batch_size=batchSize,
+                    verbose=0)  # Set verbose=0 to avoid multiple progress bars
+
+                # Apply softmax to scale to 0 to 1
+                batch_cat_weights = tf.nn.softmax(batch_predictions)
+
+                all_predictions.append(batch_predictions)
+                all_cat_weights.append(batch_cat_weights)
+                all_indices.append(batch_indices)
+
+            prog.item_progress(item, 0.4)
+
+            # Concatenate all results
+            predictions = tf.concat(all_predictions, axis=0)
+            catWeights = tf.concat(all_cat_weights, axis=0)
+            final_indices = tf.concat(all_indices, axis=0)
+
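+        # Convert to numpy so callers stay framework-agnostic; the indices are
+        # cast to int64 for consistent downstream indexing.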
+        return catWeights.numpy(), predictions.numpy(), final_indices.numpy().astype(np.int64)
 
     def findOptimalBatchSize(self, model, ds, training) -> int:
         if training and self.training_optimal_batchsize is not None: