Skip to content

Commit 4142cc7

Browse files
committed
Remove numba njit and use vectorized version
1 parent fd8578a commit 4142cc7

File tree

1 file changed

+29
-33
lines changed

1 file changed

+29
-33
lines changed

lectures/schelling.md

Lines changed: 29 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ kernelspec:
2020
</div>
2121
```
2222

23-
# Racial Segregation
23+
# Racial Segregation
2424

2525
```{index} single: Schelling Segregation Model
2626
```
@@ -191,30 +191,30 @@ class Agent:
191191
b = (self.location[1] - other.location[1])**2
192192
return sqrt(a + b)
193193
194-
def happy(self,
194+
def happy(self,
195195
agents, # List of other agents
196196
num_neighbors=10, # No. of agents viewed as neighbors
197197
require_same_type=5): # How many neighbors must be same type
198198
"""
199199
True if a sufficient number of nearest neighbors are of the same
200-
type.
200+
type.
201201
"""
202202
203203
distances = []
204-
204+
205205
# Distances is a list of pairs (d, agent), where d is distance from
206206
# agent to self
207207
for agent in agents:
208208
if self != agent:
209209
distance = self.get_distance(agent)
210210
distances.append((distance, agent))
211-
211+
212212
# Sort from smallest to largest, according to distance
213213
distances.sort()
214-
214+
215215
# Extract the neighboring agents
216216
neighbors = [agent for d, agent in distances[:num_neighbors]]
217-
217+
218218
# Count how many neighbors have the same type as self
219219
num_same_type = sum(self.type == agent.type for agent in neighbors)
220220
return num_same_type >= require_same_type
@@ -248,9 +248,9 @@ def plot_distribution(agents, cycle_num):
248248
fig, ax = plt.subplots()
249249
plot_args = {'markersize': 8, 'alpha': 0.8}
250250
ax.set_facecolor('azure')
251-
ax.plot(x_values_0, y_values_0,
251+
ax.plot(x_values_0, y_values_0,
252252
'o', markerfacecolor='orange', **plot_args)
253-
ax.plot(x_values_1, y_values_1,
253+
ax.plot(x_values_1, y_values_1,
254254
'o', markerfacecolor='green', **plot_args)
255255
ax.set_title(f'Cycle {cycle_num-1}')
256256
plt.show()
@@ -274,24 +274,24 @@ The real code is below
274274
```{code-cell} ipython3
275275
def run_simulation(num_of_type_0=600,
276276
num_of_type_1=600,
277-
max_iter=100_000, # Maximum number of iterations
278-
set_seed=1234):
277+
max_iter=100_000, # Maximum number of iterations
278+
set_seed=1234):
279279
280280
# Set the seed for reproducibility
281-
seed(set_seed)
282-
281+
seed(set_seed)
282+
283283
# Create a list of agents of type 0
284284
agents = [Agent(0) for i in range(num_of_type_0)]
285285
# Append a list of agents of type 1
286286
agents.extend(Agent(1) for i in range(num_of_type_1))
287287
288288
# Initialize a counter
289289
count = 1
290-
290+
291291
# Plot the initial distribution
292292
plot_distribution(agents, count)
293-
294-
# Loop until no agent wishes to move
293+
294+
# Loop until no agent wishes to move
295295
while count < max_iter:
296296
print('Entering loop ', count)
297297
count += 1
@@ -303,15 +303,15 @@ def run_simulation(num_of_type_0=600,
303303
no_one_moved = False
304304
if no_one_moved:
305305
break
306-
306+
307307
# Plot final distribution
308308
plot_distribution(agents, count)
309309
310310
if count < max_iter:
311311
print(f'Converged after {count} iterations.')
312312
else:
313313
print('Hit iteration bound and terminated.')
314-
314+
315315
```
316316

317317
Let's have a look at the results.
@@ -346,7 +346,7 @@ The object oriented style that we used for coding above is neat but harder to
346346
optimize than procedural code (i.e., code based around functions rather than
347347
objects and methods).
348348

349-
Try writing a new version of the model that stores
349+
Try writing a new version of the model that stores
350350

351351
* the locations of all agents as a 2D NumPy array of floats.
352352
* the types of all agents as a flat NumPy array of integers.
@@ -375,7 +375,6 @@ solution here
375375

376376
```{code-cell} ipython3
377377
from numpy.random import uniform, randint
378-
from numba import njit
379378
380379
n = 1000 # number of agents (agents = 0, ..., n-1)
381380
k = 10 # number of agents regarded as neighbors
@@ -386,13 +385,10 @@ def initialize_state():
386385
types = randint(0, high=2, size=n) # label zero or one
387386
return locations, types
388387
389-
@njit # Use Numba to accelerate computation
388+
390389
def compute_distances_from_loc(loc, locations):
391-
" Compute distance from location loc to all other points. "
392-
distances = np.empty(n)
393-
for j in range(n):
394-
distances[j] = np.linalg.norm(loc - locations[j, :])
395-
return distances
390+
""" Compute distance from location loc to all other points. """
391+
return np.linalg.norm(loc - locations, axis=1)
396392
397393
def get_neighbors(loc, locations):
398394
" Get all neighbors of a given location. "
@@ -417,7 +413,7 @@ def count_happy(locations, types):
417413
for i in range(n):
418414
happy_sum += is_happy(i, locations, types)
419415
return happy_sum
420-
416+
421417
def update_agent(i, locations, types):
422418
" Move agent if unhappy. "
423419
moved = False
@@ -432,11 +428,11 @@ def plot_distribution(locations, types, title, savepdf=False):
432428
colors = 'orange', 'green'
433429
for agent_type, color in zip((0, 1), colors):
434430
idx = (types == agent_type)
435-
ax.plot(locations[idx, 0],
436-
locations[idx, 1],
437-
'o',
431+
ax.plot(locations[idx, 0],
432+
locations[idx, 1],
433+
'o',
438434
markersize=8,
439-
markerfacecolor=color,
435+
markerfacecolor=color,
440436
alpha=0.8)
441437
ax.set_title(title)
442438
plt.show()
@@ -458,15 +454,15 @@ def sim_random_select(max_iter=100_000, flip_prob=0.01, test_freq=10_000):
458454
i = randint(0, n)
459455
moved = update_agent(i, locations, types)
460456
461-
if flip_prob > 0:
457+
if flip_prob > 0:
462458
# flip agent i's type with probability epsilon
463459
U = uniform()
464460
if U < flip_prob:
465461
current_type = types[i]
466462
types[i] = 0 if current_type == 1 else 1
467463
468464
# Every so many updates, plot and test for convergence
469-
if current_iter % test_freq == 0:
465+
if current_iter % test_freq == 0:
470466
cycle = current_iter / n
471467
plot_distribution(locations, types, f'iteration {current_iter}')
472468
if count_happy(locations, types) == n:

0 commit comments

Comments
 (0)