Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 37 additions & 9 deletions data_preprocessing/custom_multiprocess.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,45 @@
'''
Custom non-daemonic Pool class
Code adapted from https://stackoverflow.com/questions/6974695/python-process-pool-non-daemonic
Custom non-daemonic Pool class for all python version
Code adapted from https://github.com/LoLab-VU/PyDREAM/pull/17/commits
'''
import multiprocessing
import multiprocessing.pool


class NoDaemonProcess(multiprocessing.Process):
def _get_daemon(self):
class NonDaemonMixin(object):
@property
def daemon(self):
return False
def _set_daemon(self, value):

@daemon.setter
def daemon(self, val):
pass
daemon = property(_get_daemon, _set_daemon)

class MyPool(multiprocessing.pool.Pool):
Process = NoDaemonProcess

from multiprocessing import context


# Exists on all platforms
class NonDaemonSpawnProcess(NonDaemonMixin, context.SpawnProcess):
pass


class NonDaemonSpawnContext(context.SpawnContext):
Process = NonDaemonSpawnProcess


_nondaemon_context_mapper = {
'spawn': NonDaemonSpawnContext()
}


class DreamPool(multiprocessing.pool.Pool):
def __init__(self, processes=None, initializer=None, initargs=(),
maxtasksperchild=None, context=None):
if context is None:
context = multiprocessing.get_context()
context = _nondaemon_context_mapper[context._name]
super(DreamPool, self).__init__(processes=processes,
initializer=initializer,
initargs=initargs,
maxtasksperchild=maxtasksperchild,
context=context)
2 changes: 1 addition & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def allocate_clients_to_threads(args):
client_info.put((client_dict[i], args))

# Start server and get initial outputs
pool = cm.MyPool(args.thread_number, init_process, (client_info, Client))
pool = cm.DreamPool(args.thread_number, init_process, (client_info, Client))
# init server
server_dict['save_path'] = '{}/logs/{}__{}_e{}_c{}'.format(os.getcwd(),
time.strftime("%Y%m%d_%H%M%S"), args.method, args.epochs, args.client_number)
Expand Down