Skip to content

Commit e41339f

Browse files
committed
ENH: Make random numbers multithreading registrations deterministic
Stopped using the global instance of MersenneTwisterRandomVariateGenerator. Stopped calling its `GetInstance()`. Added a deterministic default-constructed `m_RandomVariateGenerator` to ElastixBase. Added pointers to this generator to AdvancedImageToImageMetric, ImageRandomSamplerBase, and CMAEvolutionStrategyOptimizer. Added `SetRandomVariateGenerator` member functions to these classes (using a default-constructed generator when `SetRandomVariateGenerator` is not yet called). Also removed the one `MersenneTwisterRandomVariateGenerator::New()` call from elastix, which was for a local generator in `ImageRandomSamplerBase::GenerateRandomNumberList()`. Instead, added `m_Seed` to ImageRandomSamplerBase, and used that seed for a default-constructed local generator. Aims to make the results of running multiple registrations parallel (multi-threaded) within a single process deterministic. Triggered by pull request InsightSoftwareConsortium/ITK#5287 "Deterministic multithreading usage of itkMersenneTwisterRandomVariateGenerator.cxx", Michal Meszaros, Mar 21, 2025.
1 parent 629be60 commit e41339f

File tree

34 files changed

+166
-86
lines changed

34 files changed

+166
-86
lines changed

Common/CostFunctions/itkAdvancedImageToImageMetric.h

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@
3535
#include "itkAdvancedBSplineDeformableTransform.h"
3636
#include "itkAdvancedCombinationTransform.h"
3737

38+
#include "elxDefaultConstruct.h"
39+
3840
#include <cassert>
3941
#include <memory> // For unique_ptr.
4042
#include <typeinfo>
@@ -317,6 +319,12 @@ class ITK_TEMPLATE_EXPORT AdvancedImageToImageMetric : public ImageToImageMetric
317319
virtual void
318320
BeforeThreadedGetValueAndDerivative(const TransformParametersType & parameters) const;
319321

322+
void
323+
SetRandomVariateGenerator(Statistics::MersenneTwisterRandomVariateGenerator & randomVariateGenerator)
324+
{
325+
m_RandomVariateGenerator = &randomVariateGenerator;
326+
}
327+
320328
protected:
321329
/** Constructor. */
322330
AdvancedImageToImageMetric();
@@ -598,6 +606,19 @@ class ITK_TEMPLATE_EXPORT AdvancedImageToImageMetric : public ImageToImageMetric
598606
itkExceptionMacro("Intentionally left unimplemented!");
599607
}
600608

609+
Statistics::MersenneTwisterRandomVariateGenerator &
610+
GetRandomVariateGenerator()
611+
{
612+
return *m_RandomVariateGenerator;
613+
}
614+
615+
// Note: Bypasses logical const-correctness
616+
Statistics::MersenneTwisterRandomVariateGenerator &
617+
GetMutableRandomVariateGenerator() const
618+
{
619+
return *m_RandomVariateGenerator;
620+
}
621+
601622
// Protected using-declaration, to avoid `-Woverloaded-virtual` warnings from GCC (GCC 11.4) or clang (macos-12).
602623
using Superclass::SetTransform;
603624

@@ -627,6 +648,10 @@ class ITK_TEMPLATE_EXPORT AdvancedImageToImageMetric : public ImageToImageMetric
627648

628649
MovingImageDerivativeScalesType m_MovingImageDerivativeScales{ MovingImageDerivativeScalesType::Filled(1.0) };
629650

651+
mutable elastix::DefaultConstruct<Statistics::MersenneTwisterRandomVariateGenerator>
652+
m_DefaultRandomVariateGenerator{};
653+
Statistics::MersenneTwisterRandomVariateGenerator * m_RandomVariateGenerator{ &m_DefaultRandomVariateGenerator };
654+
630655
// Private using-declarations, to avoid `-Woverloaded-virtual` warnings from GCC (GCC 11.4) or clang (macos-12).
631656
using Superclass::TransformPoint;
632657

Common/GTesting/itkImageRandomCoordinateSamplerGTest.cxx

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,14 +43,14 @@ GTEST_TEST(ImageRandomCoordinateSampler, CheckImageValuesOfSamples)
4343
using ImageType = itk::Image<PixelType>;
4444
using SamplerType = itk::ImageRandomCoordinateSampler<ImageType>;
4545

46-
// Use a fixed seed, in order to have a reproducible sampler output.
47-
DerefSmartPointer(MersenneTwisterRandomVariateGenerator::GetInstance()).SetSeed(1);
48-
4946
const auto image =
5047
CreateImageFilledWithSequenceOfNaturalNumbers<PixelType>(ImageType::SizeType::Filled(minimumImageSizeValue));
5148

49+
elx::DefaultConstruct<MersenneTwisterRandomVariateGenerator> randomVariateGenerator{};
50+
randomVariateGenerator.SetSeed(1);
5251
elx::DefaultConstruct<SamplerType> sampler{};
5352

53+
sampler.SetRandomVariateGenerator(randomVariateGenerator);
5454
const size_t numberOfSamples{ 3 };
5555
sampler.SetNumberOfSamples(numberOfSamples);
5656
sampler.SetInput(image);
@@ -84,9 +84,11 @@ GTEST_TEST(ImageRandomCoordinateSampler, SetSeedMakesRandomizationDeterministic)
8484
for (const SamplerType::SeedIntegerType seed : { 0, 1 })
8585
{
8686
const auto generateSamples = [seed, image] {
87+
elx::DefaultConstruct<MersenneTwisterRandomVariateGenerator> randomVariateGenerator{};
88+
randomVariateGenerator.SetSeed(seed);
8789
elx::DefaultConstruct<SamplerType> sampler{};
8890

89-
DerefSmartPointer(MersenneTwisterRandomVariateGenerator::GetInstance()).SetSeed(seed);
91+
sampler.SetRandomVariateGenerator(randomVariateGenerator);
9092
sampler.SetInput(image);
9193
sampler.Update();
9294
return std::move(Deref(sampler.GetOutput()).CastToSTLContainer());
@@ -114,9 +116,10 @@ GTEST_TEST(ImageRandomCoordinateSampler, HasSameOutputWhenUsingMultiThread)
114116
CreateImageFilledWithSequenceOfNaturalNumbers<PixelType>(ImageType::SizeType::Filled(minimumImageSizeValue));
115117

116118
const auto generateSamples = [image](const bool useMultiThread) {
117-
DerefSmartPointer(MersenneTwisterRandomVariateGenerator::GetInstance()).SetSeed(1);
119+
elx::DefaultConstruct<MersenneTwisterRandomVariateGenerator> randomVariateGenerator{};
118120

119121
elx::DefaultConstruct<SamplerType> sampler{};
122+
sampler.SetRandomVariateGenerator(randomVariateGenerator);
120123
sampler.SetUseMultiThread(useMultiThread);
121124
sampler.SetInput(image);
122125
sampler.Update();

Common/GTesting/itkImageRandomSamplerSparseMaskGTest.cxx

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,6 @@ GTEST_TEST(ImageRandomSamplerSparseMask, CheckImageValuesOfSamples)
4848
using ImageType = itk::Image<PixelType, Dimension>;
4949
using MaskSpatialObjectType = itk::ImageMaskSpatialObject<Dimension>;
5050

51-
// Use a fixed seed, in order to have a reproducible sampler output.
52-
DerefSmartPointer(MersenneTwisterRandomVariateGenerator::GetInstance()).SetSeed(1);
53-
5451
const auto imageSize = ImageType::SizeType::Filled(minimumImageSizeValue);
5552
const auto image = CreateImageFilledWithSequenceOfNaturalNumbers<PixelType>(imageSize);
5653

@@ -61,9 +58,12 @@ GTEST_TEST(ImageRandomSamplerSparseMask, CheckImageValuesOfSamples)
6158
maskSpatialObject->SetImage(maskImage);
6259
maskSpatialObject->Update();
6360

61+
elx::DefaultConstruct<MersenneTwisterRandomVariateGenerator> randomVariateGenerator{};
62+
randomVariateGenerator.SetSeed(1);
6463
elx::DefaultConstruct<itk::ImageRandomSamplerSparseMask<ImageType>> sampler{};
6564

6665
const size_t numberOfSamples{ 3 };
66+
sampler.SetRandomVariateGenerator(randomVariateGenerator);
6767
sampler.SetInput(image);
6868
sampler.SetMask(maskSpatialObject);
6969
sampler.SetNumberOfSamples(numberOfSamples);
@@ -104,9 +104,11 @@ GTEST_TEST(ImageRandomSamplerSparseMask, SetSeedMakesRandomizationDeterministic)
104104
for (const SamplerType::SeedIntegerType seed : { 0, 1 })
105105
{
106106
const auto generateSamples = [seed, image, maskSpatialObject] {
107-
elx::DefaultConstruct<SamplerType> sampler{};
107+
elx::DefaultConstruct<MersenneTwisterRandomVariateGenerator> randomVariateGenerator{};
108+
elx::DefaultConstruct<SamplerType> sampler{};
108109

109-
DerefSmartPointer(MersenneTwisterRandomVariateGenerator::GetInstance()).SetSeed(seed);
110+
randomVariateGenerator.SetSeed(seed);
111+
sampler.SetRandomVariateGenerator(randomVariateGenerator);
110112
sampler.SetInput(image);
111113
sampler.SetMask(maskSpatialObject);
112114
sampler.Update();
@@ -144,8 +146,9 @@ GTEST_TEST(ImageRandomSamplerSparseMask, HasSameOutputWhenUsingMultiThread)
144146
maskSpatialObject->Update();
145147

146148
const auto generateSamples = [image, maskSpatialObject](const bool useMultiThread) {
147-
DerefSmartPointer(MersenneTwisterRandomVariateGenerator::GetInstance()).SetSeed(1);
148-
elx::DefaultConstruct<SamplerType> sampler{};
149+
elx::DefaultConstruct<MersenneTwisterRandomVariateGenerator> randomVariateGenerator{};
150+
elx::DefaultConstruct<SamplerType> sampler{};
151+
sampler.SetRandomVariateGenerator(randomVariateGenerator);
149152
sampler.SetUseMultiThread(useMultiThread);
150153
sampler.SetInput(image);
151154
sampler.SetMask(maskSpatialObject);

Common/ImageSamplers/itkImageRandomCoordinateSampler.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,8 +131,7 @@ class ITK_TEMPLATE_EXPORT ImageRandomCoordinateSampler : public ImageRandomSampl
131131
return interpolator;
132132
}();
133133

134-
RandomGeneratorPointer m_RandomGenerator{ RandomGeneratorType::GetInstance() };
135-
InputImageSpacingType m_SampleRegionSize{ itk::MakeFilled<InputImageSpacingType>(1.0) };
134+
InputImageSpacingType m_SampleRegionSize{ itk::MakeFilled<InputImageSpacingType>(1.0) };
136135

137136
/** Generate the two corners of a sampling region, given the two corners
138137
* of an image. If UseRandomSampleRegion=false, the smallesPoint and largestPoint

Common/ImageSamplers/itkImageRandomCoordinateSampler.hxx

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -207,10 +207,12 @@ ImageRandomCoordinateSampler<TInputImage>::GenerateRandomCoordinate(
207207
const InputImageContinuousIndexType & largestContIndex,
208208
InputImageContinuousIndexType & randomContIndex)
209209
{
210+
Statistics::MersenneTwisterRandomVariateGenerator & randomVariateGenerator = Superclass::GetRandomVariateGenerator();
211+
210212
for (unsigned int i = 0; i < InputImageDimension; ++i)
211213
{
212214
randomContIndex[i] = static_cast<InputImagePointValueType>(
213-
this->m_RandomGenerator->GetUniformVariate(smallestContIndex[i], largestContIndex[i]));
215+
randomVariateGenerator.GetUniformVariate(smallestContIndex[i], largestContIndex[i]));
214216
}
215217
} // end GenerateRandomCoordinate()
216218

@@ -268,7 +270,6 @@ ImageRandomCoordinateSampler<TInputImage>::PrintSelf(std::ostream & os, Indent i
268270
Superclass::PrintSelf(os, indent);
269271

270272
os << indent << "Interpolator: " << this->m_Interpolator.GetPointer() << std::endl;
271-
os << indent << "RandomGenerator: " << this->m_RandomGenerator.GetPointer() << std::endl;
272273

273274
} // end PrintSelf()
274275

Common/ImageSamplers/itkImageRandomSamplerBase.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
#include "itkImageSamplerBase.h"
2222
#include <itkMersenneTwisterRandomVariateGenerator.h>
23+
#include "elxDefaultConstruct.h"
2324
#include <optional>
2425

2526
namespace itk
@@ -94,6 +95,13 @@ class ITK_TEMPLATE_EXPORT ImageRandomSamplerBase : public ImageSamplerBase<TInpu
9495
return m_OptionalSeed;
9596
}
9697

98+
void
99+
SetRandomVariateGenerator(Statistics::MersenneTwisterRandomVariateGenerator & randomVariateGenerator)
100+
{
101+
m_RandomVariateGenerator = &randomVariateGenerator;
102+
}
103+
104+
97105
/** The input image dimension. */
98106
itkStaticConstMacro(InputImageDimension, unsigned int, Superclass::InputImageDimension);
99107

@@ -108,6 +116,12 @@ class ITK_TEMPLATE_EXPORT ImageRandomSamplerBase : public ImageSamplerBase<TInpu
108116
void
109117
GenerateRandomNumberList();
110118

119+
Statistics::MersenneTwisterRandomVariateGenerator &
120+
GetRandomVariateGenerator()
121+
{
122+
return *m_RandomVariateGenerator;
123+
}
124+
111125
/** PrintSelf. */
112126
void
113127
PrintSelf(std::ostream & os, Indent indent) const override;
@@ -117,6 +131,10 @@ class ITK_TEMPLATE_EXPORT ImageRandomSamplerBase : public ImageSamplerBase<TInpu
117131

118132
private:
119133
std::optional<SeedIntegerType> m_OptionalSeed{};
134+
SeedIntegerType m_Seed{ 121212 + 1 };
135+
136+
elastix::DefaultConstruct<Statistics::MersenneTwisterRandomVariateGenerator> m_DefaultRandomVariateGenerator{};
137+
Statistics::MersenneTwisterRandomVariateGenerator * m_RandomVariateGenerator{ &m_DefaultRandomVariateGenerator };
120138
};
121139

122140
} // end namespace itk

Common/ImageSamplers/itkImageRandomSamplerBase.hxx

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -45,24 +45,19 @@ template <typename TInputImage>
4545
void
4646
ImageRandomSamplerBase<TInputImage>::GenerateRandomNumberList()
4747
{
48-
/** Create a random number generator. Also used in the ImageRandomConstIteratorWithIndex. */
49-
const auto localGenerator = Statistics::MersenneTwisterRandomVariateGenerator::New();
50-
51-
if (m_OptionalSeed)
52-
{
53-
localGenerator->SetSeed(*m_OptionalSeed);
54-
}
48+
elastix::DefaultConstruct<Statistics::MersenneTwisterRandomVariateGenerator> randomVariateGenerator{};
49+
randomVariateGenerator.SetSeed(m_OptionalSeed.value_or(++m_Seed));
5550

5651
/** Clear the random number list. */
5752
this->m_RandomNumberList.clear();
5853
this->m_RandomNumberList.reserve(this->m_NumberOfSamples);
5954

6055
/** Fill the list with random numbers. */
6156
const auto numPixels = static_cast<double>(this->GetCroppedInputImageRegion().GetNumberOfPixels());
62-
localGenerator->GetVariateWithOpenRange(numPixels - 0.5); // dummy jump
57+
randomVariateGenerator.GetVariateWithOpenRange(numPixels - 0.5); // dummy jump
6358
for (unsigned long i = 0; i < this->m_NumberOfSamples; ++i)
6459
{
65-
const double randomPosition = localGenerator->GetVariateWithOpenRange(numPixels - 0.5);
60+
const double randomPosition = randomVariateGenerator.GetVariateWithOpenRange(numPixels - 0.5);
6661
this->m_RandomNumberList.push_back(randomPosition);
6762
}
6863
}

Common/ImageSamplers/itkImageRandomSamplerSparseMask.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,6 @@ class ITK_TEMPLATE_EXPORT ImageRandomSamplerSparseMask : public ImageRandomSampl
9494
void
9595
GenerateData() override;
9696

97-
RandomGeneratorPointer m_RandomGenerator{ RandomGeneratorType::GetInstance() };
9897
InternalFullSamplerPointer m_InternalFullSampler{ InternalFullSamplerType::New() };
9998

10099
private:

Common/ImageSamplers/itkImageRandomSamplerSparseMask.hxx

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ ImageRandomSamplerSparseMask<TInputImage>::GenerateData()
7575
const ImageSampleContainerType & allValidSamples = Deref(this->m_InternalFullSampler->GetOutput());
7676
unsigned long numberOfValidSamples = allValidSamples.Size();
7777

78+
Statistics::MersenneTwisterRandomVariateGenerator & randomVariateGenerator = Superclass::GetRandomVariateGenerator();
7879

7980
/** If desired we exercise a multi-threaded version. */
8081
if (Superclass::m_UseMultiThread)
@@ -84,7 +85,7 @@ ImageRandomSamplerSparseMask<TInputImage>::GenerateData()
8485

8586
for (unsigned int i = 0; i < Superclass::m_NumberOfSamples; ++i)
8687
{
87-
m_RandomIndices.push_back(m_RandomGenerator->GetIntegerVariate(numberOfValidSamples - 1));
88+
m_RandomIndices.push_back(randomVariateGenerator.GetIntegerVariate(numberOfValidSamples - 1));
8889
}
8990

9091
auto & samples = sampleContainer.CastToSTLContainer();
@@ -101,7 +102,7 @@ ImageRandomSamplerSparseMask<TInputImage>::GenerateData()
101102

102103
for (unsigned int i = 0; i < Superclass::m_NumberOfSamples; ++i)
103104
{
104-
unsigned long randomIndex = this->m_RandomGenerator->GetIntegerVariate(numberOfValidSamples - 1);
105+
unsigned long randomIndex = randomVariateGenerator.GetIntegerVariate(numberOfValidSamples - 1);
105106
sampleVector.push_back(allValidSamples.ElementAt(randomIndex));
106107
}
107108

@@ -157,7 +158,6 @@ ImageRandomSamplerSparseMask<TInputImage>::PrintSelf(std::ostream & os, Indent i
157158
Superclass::PrintSelf(os, indent);
158159

159160
os << indent << "InternalFullSampler: " << this->m_InternalFullSampler.GetPointer() << std::endl;
160-
os << indent << "RandomGenerator: " << this->m_RandomGenerator.GetPointer() << std::endl;
161161

162162
} // end PrintSelf()
163163

Common/ImageSamplers/itkMultiInputImageRandomCoordinateSampler.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,8 +133,7 @@ class ITK_TEMPLATE_EXPORT MultiInputImageRandomCoordinateSampler : public ImageR
133133
return interpolator;
134134
}();
135135

136-
RandomGeneratorPointer m_RandomGenerator{ RandomGeneratorType::GetInstance() };
137-
InputImageSpacingType m_SampleRegionSize{ itk::MakeFilled<InputImageSpacingType>(1.0) };
136+
InputImageSpacingType m_SampleRegionSize{ itk::MakeFilled<InputImageSpacingType>(1.0) };
138137

139138
/** Generate the two corners of a sampling region. */
140139
virtual void

0 commit comments

Comments
 (0)