Skip to content

Adjust straight guide to target total number of threads. #48

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 20 additions & 12 deletions acc/components/optics/guide.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np

from math import ceil, sqrt, tanh
from numba import cuda, float32, void
from numba import cuda, float32, int64, void
from time import time

from mcni.AbstractComponent import AbstractComponent
Expand Down Expand Up @@ -131,20 +131,26 @@ def propagate(

# CUDA kernel that applies `propagate` to rows of the 2-D float32 array
# `neutrons`, passing the ten scalar parameters through unchanged.
# NOTE(review): this span looks like a rendered diff whose +/- markers and
# indentation were stripped, so replaced and replacement lines sit back to
# back (two signature endings, two parameter lists, two bodies). Presumably
# the second line of each pair is the current one -- confirm against the
# repository before relying on it.
@cuda.jit(void(float32, float32, float32, float32, float32,
float32, float32, float32, float32, float32,
float32[:, :]))
float32[:, :], int64))
def process_kernel(
ww, hh, hw1, hh1, l,
R0, Qc, alpha, m, W,
neutrons
neutrons, batch_size
):
# One-row-per-thread form: the thread's grid index picks the row, and the
# bounds check skips threads whose index is past the last row.
x = cuda.grid(1)
if x < len(neutrons):
# Batched form: each thread walks a contiguous run of up to `batch_size`
# rows starting at `work_unit * batch_size`.
neutron_count = len(neutrons)
work_unit = cuda.grid(1)
neutron_start = work_unit * batch_size
# Clamp the end so the last batch stops at the end of the array; a thread
# whose start is already past the end gets an empty range and does nothing.
neutron_end = min(neutron_count, neutron_start + batch_size)
for neutron_index in range(neutron_start, neutron_end):
propagate(
ww, hh, hw1, hh1, l,
R0, Qc, alpha, m, W,
neutrons[x]
neutrons[neutron_index]
)
return


# Target for the total number of GPU threads per kernel launch. It is
# compared against the neutron count with `min(...)`, so it is kept as an
# integer (not the float literal 1e5) to keep the derived thread count
# integral.
thread_count_target = 100000
# Number of threads per CUDA block, used as the second element of the kernel
# launch configuration and to derive the block count.
threads_per_block = 512


def call_process(
Expand All @@ -153,13 +159,15 @@ def call_process(
in_neutrons
):
neutron_count = len(in_neutrons)
threads_per_block = 512
number_of_blocks = ceil(neutron_count / threads_per_block)
print("{} blocks, {} threads".format(number_of_blocks, threads_per_block))
process_kernel[number_of_blocks, threads_per_block](
thread_count = min(thread_count_target, neutron_count)
block_count = ceil(thread_count / threads_per_block)
neutrons_per_thread = ceil(neutron_count / thread_count)
print("{} blocks, {} threads per block, {} neutrons per thread".format(
block_count, threads_per_block, neutrons_per_thread))
process_kernel[block_count, threads_per_block](
ww, hh, hw1, hh1, l,
R0, Qc, alpha, m, W,
in_neutrons
in_neutrons, neutrons_per_thread
)
cuda.synchronize()

Expand Down