Skip to content

Commit ca1628d

Browse files
committed
Make using CUDA an optional parameter
This also makes it easier to run tests in CPU-only environments
1 parent 88316c4 commit ca1628d

File tree

4 files changed

+109
-46
lines changed

4 files changed

+109
-46
lines changed

superpixel_classification/SuperpixelClassification/SuperpixelClassification.xml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,13 @@
100100
<label>Train model</label>
101101
<default>true</default>
102102
</boolean>
103+
<boolean>
104+
<name>useCuda</name>
105+
<longflag>usecuda</longflag>
106+
<description>Whether or not to use GPU/cuda (true) or cpu (false).</description>
107+
<label>Use CUDA</label>
108+
<default>true</default>
109+
</boolean>
103110
<integer>
104111
<name>batchSize</name>
105112
<longflag>batchsize</longflag>

superpixel_classification/SuperpixelClassification/SuperpixelClassificationBase.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -505,7 +505,7 @@ def trainModelAddItem(self, gc, record, item, annotrec, elem, feature,
505505

506506
def trainModel(self, gc, folderId, annotationName, features, modelFolderId,
507507
batchSize, epochs, trainingSplit, randomInput, labelList,
508-
excludeLabelList, prog):
508+
excludeLabelList, use_cuda, prog):
509509
itemsAndAnnot = self.getItemsAndAnnotations(gc, folderId, annotationName)
510510
with tempfile.TemporaryDirectory(dir=os.getcwd()) as tempdir:
511511
trainingPath = os.path.join(tempdir, 'training.h5')
@@ -544,7 +544,7 @@ def trainModel(self, gc, folderId, annotationName, features, modelFolderId,
544544
prog.progress(0)
545545
history, modelPath = self.trainModelDetails(
546546
record, annotationName, batchSize, epochs, itemsAndAnnot, prog, tempdir,
547-
trainingSplit)
547+
trainingSplit, use_cuda)
548548

549549
modTrainingPath = os.path.join(tempdir, '%s ModTraining Epoch %d.h5' % (
550550
annotationName, self.getCurrentEpoch(itemsAndAnnot)))
@@ -568,7 +568,7 @@ def trainModel(self, gc, folderId, annotationName, features, modelFolderId,
568568

569569
def predictLabelsForItem(self, gc, annotationName, annotationFolderId, tempdir, model, item,
570570
annotrec, elem, feature, curEpoch, userId, labels, groups,
571-
makeHeatmaps, radius, magnification, certainty, batchSize, prog):
571+
makeHeatmaps, radius, magnification, certainty, batchSize, use_cuda, prog):
572572
import al_bench.factory
573573

574574
print('Predicting %s' % (item['name']))
@@ -771,7 +771,7 @@ def makeHeatmapsForItem(self, gc, annotationName, userId, tempdir, radius, item,
771771

772772
def predictLabels(self, gc, folderId, annotationName, features, modelFolderId,
773773
annotationFolderId, saliencyMaps, radius, magnification,
774-
certainty, batchSize, prog):
774+
certainty, batchSize, use_cuda, prog):
775775
itemsAndAnnot = self.getItemsAndAnnotations(gc, folderId, annotationName)
776776
curEpoch = self.getCurrentEpoch(itemsAndAnnot)
777777
folder = gc.getFolder(folderId)
@@ -833,7 +833,7 @@ def predictLabels(self, gc, folderId, annotationName, features, modelFolderId,
833833
self.predictLabelsForItem(
834834
gc, annotationName, annotationFolderId, tempdir, model, item, annotrec, elem,
835835
features.get(item['_id']), curEpoch, userId, labels, groups, saliencyMaps,
836-
radius, magnification, certainty, batchSize, prog)
836+
radius, magnification, certainty, batchSize, use_cuda, prog)
837837
prog.progress(1)
838838

839839
def main(self, args):
@@ -864,5 +864,5 @@ def main(self, args):
864864

865865
self.predictLabels(
866866
gc, args.images, args.annotationName, features, args.modeldir, args.annotationDir,
867-
args.heatmaps, args.radius, args.magnification, args.certainty, args.batchSize,
867+
args.heatmaps, args.radius, args.magnification, args.certainty, args.batchSize, args.useCuda,
868868
prog)

superpixel_classification/SuperpixelClassification/SuperpixelClassificationTensorflow.py

Lines changed: 87 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -35,33 +35,56 @@ class SuperpixelClassificationTensorflow(SuperpixelClassificationBase):
3535
def __init__(self):
3636
self.training_optimal_batchsize: Optional[int] = None
3737
self.prediction_optimal_batchsize: Optional[int] = None
38+
self.use_cuda = False
3839

3940
def trainModelDetails(self, record, annotationName, batchSize, epochs, itemsAndAnnot, prog,
40-
tempdir, trainingSplit):
41-
# print(f'Tensorflow trainModelDetails(batchSize={batchSize}, ...)')
42-
# make model
43-
num_classes = len(record['labels'])
44-
model = tf.keras.Sequential([
45-
tf.keras.layers.Rescaling(1.0 / 255),
46-
tf.keras.layers.Conv2D(16, 3, padding='same', activation='relu'),
47-
tf.keras.layers.MaxPooling2D(),
48-
tf.keras.layers.Conv2D(32, 3, padding='same', activation='relu'),
49-
tf.keras.layers.MaxPooling2D(),
50-
tf.keras.layers.Conv2D(64, 3, padding='same', activation='relu'),
51-
tf.keras.layers.MaxPooling2D(),
52-
tf.keras.layers.Flatten(),
53-
# tf.keras.layers.Dropout(0.2),
54-
tf.keras.layers.Dense(128, activation='relu'),
55-
tf.keras.layers.Dense(num_classes)])
56-
prog.progress(0.2)
57-
model.compile(optimizer='adam',
58-
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
59-
metrics=['accuracy'])
41+
tempdir, trainingSplit, use_cuda):
42+
self.use_cuda = use_cuda
43+
44+
# Enable GPU memory growth globally to avoid precondition errors
45+
gpus = tf.config.list_physical_devices('GPU')
46+
if gpus and self.use_cuda:
47+
try:
48+
for gpu in gpus:
49+
tf.config.experimental.set_memory_growth(gpu, True)
50+
except RuntimeError as e:
51+
print(f"Could not set memory growth: {e}")
52+
if not self.use_cuda:
53+
tf.config.set_visible_devices([], 'GPU')
54+
device = "gpu" if use_cuda else "cpu"
55+
print(f"Using device: {device}")
56+
57+
# Dataset preparation (outside strategy scope)
58+
ds_h5 = record['ds']
59+
labelds_h5 = record['labelds']
60+
# Fully load to memory and break h5py reference
61+
ds_numpy = np.array(ds_h5[:])
62+
labelds_numpy = np.array(labelds_h5[:])
63+
64+
strategy = tf.distribute.MirroredStrategy()
65+
with strategy.scope():
66+
num_classes = len(record['labels'])
67+
model = tf.keras.Sequential([
68+
tf.keras.layers.Rescaling(1.0 / 255),
69+
tf.keras.layers.Conv2D(16, 3, padding='same', activation='relu'),
70+
tf.keras.layers.MaxPooling2D(),
71+
tf.keras.layers.Conv2D(32, 3, padding='same', activation='relu'),
72+
tf.keras.layers.MaxPooling2D(),
73+
tf.keras.layers.Conv2D(64, 3, padding='same', activation='relu'),
74+
tf.keras.layers.MaxPooling2D(),
75+
tf.keras.layers.Flatten(),
76+
tf.keras.layers.Dense(128, activation='relu'),
77+
tf.keras.layers.Dense(num_classes)])
78+
prog.progress(0.2)
79+
model.compile(optimizer='adam',
80+
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
81+
metrics=['accuracy'])
82+
6083
prog.progress(0.7)
61-
# generate split
62-
full_ds = tf.data.Dataset.from_tensor_slices((record['ds'], record['labelds']))
63-
full_ds = full_ds.shuffle(1000) # add seed=123 ?
64-
count = len(full_ds)
84+
# generate split using numpy arrays
85+
full_ds = tf.data.Dataset.from_tensor_slices((ds_numpy, labelds_numpy))
86+
full_ds = full_ds.shuffle(1000)
87+
count = len(ds_numpy)
6588
train_size = int(count * trainingSplit)
6689
if batchSize < 1:
6790
batchSize = self.findOptimalBatchSize(model, full_ds, training=True)
@@ -85,24 +108,53 @@ def trainModelDetails(self, record, annotationName, batchSize, epochs, itemsAndA
85108
self.saveModel(model, modelPath)
86109
return history, modelPath
87110

111+
def _get_device(self, use_cuda):
112+
if tf.config.list_physical_devices('GPU') and use_cuda:
113+
return '/GPU:0'
114+
return '/CPU:0'
115+
88116
def predictLabelsForItemDetails(
89-
self, batchSize, ds: h5py._hl.dataset.Dataset, item, model, prog,
117+
self, batchSize, ds: h5py._hl.dataset.Dataset, indices, item, model, use_cuda, prog,
90118
):
91-
# print(f'Tensorflow predictLabelsForItemDetails(batchSize={batchSize}, ...)')
92119
if batchSize < 1:
93120
batchSize = self.findOptimalBatchSize(
94121
model, tf.data.Dataset.from_tensor_slices(ds), training=False,
95122
)
96123
print(f'Optimal batch size for prediction = {batchSize}')
97-
predictions = model.predict(
98-
ds,
99-
batch_size=batchSize,
100-
callbacks=[_LogTensorflowProgress(
101-
prog, (ds.shape[0] + batchSize - 1) // batchSize, 0.05, 0.35, item)])
102-
prog.item_progress(item, 0.4)
103-
# softmax to scale to 0 to 1
104-
catWeights = tf.nn.softmax(predictions)
105-
return catWeights, predictions
124+
125+
device = self._get_device(use_cuda)
126+
with tf.device(device):
127+
# Create a dataset that pairs the data with their indices
128+
dataset = tf.data.Dataset.from_tensor_slices((ds, indices))
129+
dataset = dataset.batch(batchSize)
130+
131+
# Initialize arrays to store results
132+
all_predictions = []
133+
all_cat_weights = []
134+
all_indices = []
135+
136+
# Iterate through batches manually to keep track of indices
137+
for data, batch_indices in dataset:
138+
batch_predictions = model.predict(
139+
data,
140+
batch_size=batchSize,
141+
verbose=0) # Set verbose=0 to avoid multiple progress bars
142+
143+
# Apply softmax to scale to 0 to 1
144+
batch_cat_weights = tf.nn.softmax(batch_predictions)
145+
146+
all_predictions.append(batch_predictions)
147+
all_cat_weights.append(batch_cat_weights)
148+
all_indices.append(batch_indices)
149+
150+
prog.item_progress(item, 0.4)
151+
152+
# Concatenate all results
153+
predictions = tf.concat(all_predictions, axis=0)
154+
catWeights = tf.concat(all_cat_weights, axis=0)
155+
final_indices = tf.concat(all_indices, axis=0)
156+
157+
return catWeights.numpy(), predictions.numpy(), final_indices.numpy().astype(np.int64)
106158

107159
def findOptimalBatchSize(self, model, ds, training) -> int:
108160
if training and self.training_optimal_batchsize is not None:

superpixel_classification/SuperpixelClassification/SuperpixelClassificationTorch.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -66,12 +66,10 @@ class _BayesianPatchTorchModel(bbald.consistent_mc_dropout.BayesianModule):
6666
# A Bayesian model that takes patches (2-dimensional shape) rather than vectors
6767
# (1-dimensional shape) as input. It is useful when feature != 'vector' and
6868
# SuperpixelClassificationBase.certainty == 'batchbald'.
69-
def __init__(self, num_classes: int) -> None:
69+
def __init__(self, num_classes: int, device: torch.device) -> None:
7070
# Set `self.device` as early as possible so that other code does not lock out
7171
# what we want.
72-
self.device: str = torch.device(
73-
('cuda' if torch.cuda.is_available() and torch.cuda.device_count() > 0 else 'cpu'),
74-
)
72+
self.device : torch.device = device
7573
# print(f'Initial model.device = {self.device}')
7674
super(_BayesianPatchTorchModel, self).__init__()
7775

@@ -311,7 +309,10 @@ def trainModelDetails(
311309
prog: ProgressHelper,
312310
tempdir: str,
313311
trainingSplit: float,
312+
cuda : bool,
314313
):
314+
device = torch.device("cuda" if cuda else "cpu")
315+
print(f"Using device: {device}")
315316
# make model
316317
num_classes: int = len(record['labels'])
317318
model: torch.nn.Module
@@ -507,7 +508,7 @@ def fitModel(
507508
return history
508509

509510
def predictLabelsForItemDetails(
510-
self, batchSize: int, ds_h5, item, model: torch.nn.Module, prog: ProgressHelper,
511+
self, batchSize: int, ds_h5, item, model: torch.nn.Module, use_cuda : bool, prog: ProgressHelper,
511512
):
512513
# print(f'Torch predictLabelsForItemDetails(batchSize={batchSize}, ...)')
513514
num_superpixels: int = ds_h5.shape[0]
@@ -528,6 +529,9 @@ def predictLabelsForItemDetails(
528529
)
529530
if self.certainty == 'batchbald'
530531
else dict(num_superpixels=num_superpixels, num_classes=num_classes)
532+
# also set on model.device, ideally
533+
#device = torch.device("cuda" if use_cuda else "cpu")
534+
531535
)
532536
for cb in callbacks:
533537
cb.on_predict_begin(logs=logs)

0 commit comments

Comments
 (0)