Skip to content

Commit

Permalink
BootstrapSampler takes in a ratio as sample size parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
xapharius committed Feb 26, 2015
1 parent 2aa904b commit f7b92b5
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 41 deletions.
Empty file.
32 changes: 16 additions & 16 deletions Engine/src/simulator/sampler/abstract_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,19 @@ def __init__(self):
Constructor
'''
self.dataset = None
self.nrObs = None
self.nr_obs = None
self.data_hist = None
self.sample_hists = None
self.nrSamples = None
self.nr_samples = None

def bind_data(self, dataset):
'''
Bind dataset to Sampler in order to record statistics during sampling
'''
self.dataset = dataset
self.nrObs = len(dataset)
self.nrSamples = 0
self.data_hist = np.zeros(self.nrObs) # how often each element has been sampled
self.nr_obs = len(dataset)
self.nr_samples = 0
self.data_hist = np.zeros(self.nr_obs) # how often each element has been sampled
self.sample_hists = [] # array of histograms for each sample

@abstractmethod
Expand All @@ -46,42 +46,42 @@ def add_sample_histogram(self, sample_hist_arr):
'''
Add histogram for sample to the pool of sample histograms
'''
assert len(sample_hist_arr) == self.nrObs, "Sample histogram doesn't contain all obs from the dataset"
assert len(sample_hist_arr) == self.nr_obs, "Sample histogram doesn't contain all obs from the dataset"
self.sample_hists.append(sample_hist_arr)
self.data_hist += sample_hist_arr
self.nrSamples += 1
self.nr_samples += 1

def plot_data_histogram(self):
'''
Plot histogram showing the total sampled data
'''
plt.plot(range(0,self.nrObs), self.data_hist)
plt.plot(range(0,self.nr_obs), self.data_hist)
plt.show()

def plot_sample_histogram(self, sample_number):
'''
Samples generated are indexed starting from 0
'''
plt.plot(range(0,self.nrObs), self.sample_hists[sample_number])
plt.plot(range(0,self.nr_obs), self.sample_hists[sample_number])
plt.show()

def plot_sample_histograms(self):
'''
Plot all sample histograms on different subplots
Cracks for nrSamples < 3 (since subplot is then one dimensional), so return None
Fails for nr_samples < 3 (since the subplot array is then one-dimensional), so return None
'''
if self.nrSamples < 3: return
hEdgeSubplot= int(math.ceil(math.sqrt(self.nrSamples)))
vEdgeSubplot= int(round(math.sqrt(self.nrSamples)))
if self.nr_samples < 3: return
hEdgeSubplot= int(math.ceil(math.sqrt(self.nr_samples)))
vEdgeSubplot= int(round(math.sqrt(self.nr_samples)))
fig, axarr = plt.subplots(vEdgeSubplot, hEdgeSubplot, sharex = True, sharey = True)
for i in range(self.nrSamples):
for i in range(self.nr_samples):
ix = i/hEdgeSubplot
iy = i%hEdgeSubplot
axarr[ix, iy].plot(range(self.nrObs), self.sample_hists[i])
axarr[ix, iy].plot(range(self.nr_obs), self.sample_hists[i])
axarr[ix,iy].axes.get_xaxis().set_visible(False)
axarr[ix,iy].axes.get_yaxis().set_visible(False)
#delete empty subplots
for i in range(self.nrSamples, vEdgeSubplot*hEdgeSubplot):
for i in range(self.nr_samples, vEdgeSubplot*hEdgeSubplot):
ix = i/hEdgeSubplot
iy = i%hEdgeSubplot
fig.delaxes(axarr[ix,iy])
Expand Down
24 changes: 17 additions & 7 deletions Engine/src/simulator/sampler/bootstrap_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,27 @@
import random

class BootstrapSampler(AbstractSampler):
'''
Samples with replacement from given data at a size_ratio specified in the constructor
'''

def __init__(self, sampleSize = None):
def __init__(self, sample_size_ratio = None):
'''
Constructor
@param sample_size_ratio: size of the bootstrap sample as a fraction of the dataset's size, in [0,1].
Defaults to 1, i.e. the sample contains as many observations as the dataset.
'''
self.sampleSize = sampleSize
if sample_size_ratio is not None and (sample_size_ratio < 0 or sample_size_ratio > 1):
raise Exception("sample_size_ratio not between 0 and 1")
self.sample_size_ratio = sample_size_ratio
self.sample_size = None

def bind_data(self, dataset):
super(type(self), self).bind_data(dataset)
if self.sampleSize is None:
self.sampleSize = self.nrObs
if self.sample_size_ratio is None:
self.sample_size_ratio = 1

self.sample_size = int(round(self.sample_size_ratio * self.nr_obs))

def sample(self):
'''
Expand All @@ -29,11 +39,11 @@ def sample(self):
if self.dataset is None:
raise Exception("Data not bound")

sample_hist = self.nrObs * [0]
sample_hist = self.nr_obs * [0]
sample_data = [] # empty list, vstack later

for _ in range(self.sampleSize):
index = random.randint(0, self.nrObs-1)
for _ in range(self.sample_size):
index = random.randint(0, self.nr_obs-1)
sample_data.append(self.dataset[index])
sample_hist[index] += 1

Expand Down
7 changes: 3 additions & 4 deletions Engine/src/tests/simulator/sampler/abstract_sampler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,17 @@ def setUp(self):
self.asmpl.bind_data(self.dataset)

def test_bind_data(self):
assert self.asmpl.nrObs == 3
assert self.asmpl.nr_obs == 3
assert self.asmpl.dataset is not None
assert (self.asmpl.data_hist == [0,0,0]).all()
assert self.asmpl.nrSamples == 0
assert self.asmpl.nr_samples == 0

def test_add_sample_histogram(self):
self.asmpl.add_sample_histogram([1,1,2])
self.asmpl.add_sample_histogram([1,2,5])
print self.asmpl.data_hist
assert (self.asmpl.data_hist == [2,3,7]).all()
assert len(self.asmpl.sample_hists) == 2
assert self.asmpl.nrSamples == 2
assert self.asmpl.nr_samples == 2

'''
#Visual Test
Expand Down
26 changes: 12 additions & 14 deletions Engine/src/tests/simulator/sampler/bootstrap_sampler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,31 +19,29 @@ def tearDown(self):


def test_constructor(self):
bs = BootstrapSampler(100)
assert bs.sampleSize == 100
bs = BootstrapSampler(0.5)
assert bs.sample_size_ratio == 0.5
self.assertRaises(Exception, BootstrapSampler, -0.1)
self.assertRaises(Exception, BootstrapSampler, 10)

def test_data_not_bound(self):
bs = BootstrapSampler(100)
try:
bs.sample()
except Exception:
return
assert False
bs = BootstrapSampler()
self.assertRaises(Exception, bs.sample)

def test_sampleSize(self):
def test_sample_size(self):
bs = BootstrapSampler()
assert bs.sampleSize == None
assert bs.sample_size == None
bs.bind_data(self.data)
assert bs.sampleSize == self.data.shape[0]
assert bs.sample_size == self.data.shape[0]

def test_sample(self):
bs = BootstrapSampler(100)
bs = BootstrapSampler(0.5)
bs.bind_data(self.data)
sample = bs.sample()

assert bs.nrSamples == 1
assert bs.nr_samples == 1
assert len(bs.sample_hists) == 1
assert sample.shape[0] == 100
assert sample.shape[0] == int(round(0.5 * self.data.shape[0]))
assert sample.shape[1] == self.data.shape[1]
assert not (bs.data_hist == np.zeros(self.data.shape[0])).all()
'''
Expand Down

0 comments on commit f7b92b5

Please sign in to comment.