Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
249 changes: 224 additions & 25 deletions sdks/python/apache_beam/utils/multi_process_shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@
import logging
import multiprocessing.managers
import os
import time
import traceback
import atexit
import sys
import tempfile
import threading
from typing import Any
Expand Down Expand Up @@ -79,6 +83,10 @@ def singletonProxy_release(self):
assert self._SingletonProxy_valid
self._SingletonProxy_valid = False

def unsafe_hard_delete(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you help me understand why we need the unsafe_hard_delete changes? It's not really clear to me what behavior this enables that we can't already achieve.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's mainly because models are passed around directly as a _SingletonProxy rather than a _SingletonEntry, so we need a way to call delete directly on the _SingletonProxy.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok - let's at least give it a name like singletonProxy_unsafe_hard_delete. Otherwise we will run into issues if someone has an object with a function or property called unsafe_hard_delete, which seems like it could happen.

assert self._SingletonProxy_valid
self._SingletonProxy_entry.unsafe_hard_delete()

def __getattr__(self, name):
if not self._SingletonProxy_valid:
raise RuntimeError('Entry was released.')
Expand All @@ -105,13 +113,16 @@ def __dir__(self):
dir = self._SingletonProxy_entry.obj.__dir__()
dir.append('singletonProxy_call__')
dir.append('singletonProxy_release')
dir.append('unsafe_hard_delete')
return dir


class _SingletonEntry:
"""Represents a single, refcounted entry in this process."""
def __init__(self, constructor, initialize_eagerly=True):
def __init__(
self, constructor, initialize_eagerly=True, hard_delete_callback=None):
self.constructor = constructor
self._hard_delete_callback = hard_delete_callback
self.refcount = 0
self.lock = threading.Lock()
if initialize_eagerly:
Expand Down Expand Up @@ -141,14 +152,28 @@ def unsafe_hard_delete(self):
if self.initialied:
del self.obj
self.initialied = False
if self._hard_delete_callback:
self._hard_delete_callback()


class _SingletonManager:
entries: Dict[Any, Any] = {}

def register_singleton(self, constructor, tag, initialize_eagerly=True):
def __init__(self):
self._hard_delete_callback = None

def set_hard_delete_callback(self, callback):
self._hard_delete_callback = callback

def register_singleton(
self,
constructor,
tag,
initialize_eagerly=True,
hard_delete_callback=None):
assert tag not in self.entries, tag
self.entries[tag] = _SingletonEntry(constructor, initialize_eagerly)
self.entries[tag] = _SingletonEntry(
constructor, initialize_eagerly, hard_delete_callback)

def has_singleton(self, tag):
return tag in self.entries
Expand All @@ -160,7 +185,8 @@ def release_singleton(self, tag, obj):
return self.entries[tag].release(obj)

def unsafe_hard_delete_singleton(self, tag):
  """Forcibly deletes the object registered under ``tag``.

  The entry's own ``unsafe_hard_delete`` runs first; afterwards the
  manager-level hard-delete callback is invoked if one has been installed
  via ``set_hard_delete_callback``.

  Args:
    tag: Key of the entry in ``self.entries`` to delete.

  Raises:
    KeyError: If no entry is registered under ``tag``.
  """
  self.entries[tag].unsafe_hard_delete()
  # The manager-level callback is optional: it is None unless
  # set_hard_delete_callback() was called, so calling it unconditionally
  # would raise TypeError on a manager that never installed one.
  if self._hard_delete_callback is not None:
    self._hard_delete_callback()


_process_level_singleton_manager = _SingletonManager()
Expand Down Expand Up @@ -200,9 +226,99 @@ def __call__(self, *args, **kwargs):
def __getattr__(self, name):
return getattr(self._proxyObject, name)

def __setstate__(self, state):
self.__dict__.update(state)

def __getstate__(self):
return self.__dict__
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assume this is so that this is pickleable, but is it valid? Normally I'd expect this to not be pickleable since the proxy objects aren't necessarily valid in another context

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, this is exactly what was needed for the pickling stuff. It does seem to be valid in testing with the custom-built Beam version loaded on a custom container.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would only be valid if you unpickle onto the same machine (and maybe even in the same process). Could you remind me what unpickling issues you ran into?


def get_auto_proxy_object(self):
return self._proxyObject

def unsafe_hard_delete(self):
  """Asks the wrapped proxy to hard delete the shared object (best effort).

  This method never raises. Failures of the underlying call are logged
  instead of propagated, so callers receive no indication that the delete
  did not go through.
  """
  try:
    self._proxyObject.unsafe_hard_delete()
  except (EOFError, ConnectionResetError, BrokenPipeError) as e:
    # A dropped connection is the expected outcome when the serving side
    # terminates as part of the delete, so it is not treated as a failure.
    # Leave a debug trace rather than discarding it silently.
    logging.debug(
        "Connection closed while hard deleting shared object proxy: %r", e)
  except Exception as e:
    logging.warning(
        "Exception %s when trying to hard delete shared object proxy", e)


def _run_server_process(address_file, tag, constructor, authkey):
"""
Runs in a separate process.
Includes a 'Suicide Pact' monitor: If parent dies, I die.
"""
parent_pid = os.getppid()

def cleanup_files():
  """Best-effort removal of this server's address file and error file."""
  logging.info("Server process exiting. Deleting files for %s", tag)
  # Each file is removed independently so that a failure on the address file
  # does not leave the ".error" file behind. Attempting the removal directly
  # (instead of checking existence first) also avoids the window in which the
  # file could disappear between the check and the removal.
  for stale_path in (address_file, address_file + ".error"):
    try:
      os.remove(stale_path)
    except OSError:
      # Already gone or not removable; cleanup is best effort.
      pass

def handle_unsafe_hard_delete():
cleanup_files()
os._exit(0)

def _monitor_parent():
"""Checks if parent is alive every second."""
while True:
try:
os.kill(parent_pid, 0)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are we sending a kill signal to the parent process? Isn't this the opposite of what we want?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This does not actually send a kill signal; it only uses that interface to perform a check. With signal 0, os.kill raises OSError if parent_pid is dead, and if the parent is alive nothing happens.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if alive nothing happens.

Could you help me understand why this happens? https://www.geeksforgeeks.org/python/python-os-kill-method/ seems to say this will actually send the kill signal. Does the parent just ignore it?

except OSError:
logging.warning(
"Process %s detected Parent %s died. Self-destructing.",
os.getpid(),
parent_pid)
cleanup_files()
os._exit(0)
time.sleep(0.5)

atexit.register(cleanup_files)

try:
t = threading.Thread(target=_monitor_parent, daemon=True)
t.start()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might be better to start this after we've initialized our MPS object to avoid racy unsafe hard deletes

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated.


logging.getLogger().setLevel(logging.INFO)
multiprocessing.current_process().authkey = authkey

serving_manager = _SingletonRegistrar(
address=('localhost', 0), authkey=authkey)
_process_level_singleton_manager.set_hard_delete_callback(
handle_unsafe_hard_delete)
_process_level_singleton_manager.register_singleton(
constructor,
tag,
initialize_eagerly=True,
hard_delete_callback=handle_unsafe_hard_delete)

server = serving_manager.get_server()
logging.info(
'Process %s: Proxy serving %s at %s', os.getpid(), tag, server.address)

with open(address_file + '.tmp', 'w') as fout:
fout.write('%s:%d' % server.address)
os.rename(address_file + '.tmp', address_file)

server.serve_forever()

except Exception:
tb = traceback.format_exc()
try:
with open(address_file + ".error.tmp", 'w') as fout:
fout.write(tb)
os.rename(address_file + ".error.tmp", address_file + ".error")
except Exception:
print(f"CRITICAL ERROR IN SHARED SERVER:\n{tb}", file=sys.stderr)
os._exit(1)


class MultiProcessShared(Generic[T]):
"""MultiProcessShared is used to share a single object across processes.
Expand Down Expand Up @@ -252,7 +368,8 @@ def __init__(
tag: Any,
*,
path: str = tempfile.gettempdir(),
always_proxy: Optional[bool] = None):
always_proxy: Optional[bool] = None,
spawn_process: bool = False):
self._constructor = constructor
self._tag = tag
self._path = path
Expand All @@ -262,6 +379,7 @@ def __init__(
self._rpc_address = None
self._cross_process_lock = fasteners.InterProcessLock(
os.path.join(self._path, self._tag) + '.lock')
self._spawn_process = spawn_process

def _get_manager(self):
if self._manager is None:
Expand Down Expand Up @@ -301,6 +419,10 @@ def acquire(self):
# Caveat: They must always agree, as they will be ignored if the object
# is already constructed.
singleton = self._get_manager().acquire_singleton(self._tag)
# Trigger a sweep of zombie processes.
# calling active_children() has the side-effect of joining any finished
# processes, effectively reaping zombies from previous unsafe_hard_deletes.
if self._spawn_process: multiprocessing.active_children()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if self._spawn_process: multiprocessing.active_children()
if self._spawn_process:
multiprocessing.active_children()

style nit to be consistent with the rest of the repo.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated.

return _AutoProxyWrapper(singleton)

def release(self, obj):
Expand All @@ -315,25 +437,102 @@ def unsafe_hard_delete(self):
to this object exist, or (b) you are ok with all existing references to
this object throwing strange errors when derefrenced.
"""
self._get_manager().unsafe_hard_delete_singleton(self._tag)
try:
self._get_manager().unsafe_hard_delete_singleton(self._tag)
except (EOFError, ConnectionResetError, BrokenPipeError):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd typically expect the caller to catch/handle this. As it is, there is no indication passed back that this call failed.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good! Updated.

pass
except Exception as e:
logging.warning(
"Exception %s when trying to hard delete shared object %s",
e,
self._tag)

def _create_server(self, address_file):
  """Starts the server hosting the shared object and publishes its address.

  When ``self._spawn_process`` is set, the object is hosted in a dedicated
  child process; otherwise it is hosted on a daemon thread of the current
  process. In both cases the server address is written to ``address_file``
  as ``host:port``.

  Args:
    address_file: Path of the file that will hold the server address.

  Raises:
    RuntimeError: If the spawned server process reports an error or exits
      before publishing its address.
  """
  if self._spawn_process:
    self._create_server_in_child_process(address_file)
  else:
    self._create_server_in_current_process(address_file)


def _create_server_in_child_process(self, address_file):
  """Spawns a dedicated server process and waits for its address file."""
  error_file = address_file + ".error"

  # Discard an error report left over from an earlier failed start so it is
  # not mistaken for a failure of the process launched below.
  try:
    os.remove(error_file)
  except OSError:
    pass

  ctx = multiprocessing.get_context('spawn')
  p = ctx.Process(
      target=_run_server_process,
      args=(address_file, self._tag, self._constructor, AUTH_KEY),
      daemon=False  # Must be False for nested proxies
  )
  p.start()
  logging.info("Parent: Waiting for %s to write address file...", self._tag)

  def cleanup_process():
    if p.is_alive():
      logging.info(
          "Parent: Terminating server process %s for %s", p.pid, self._tag)
      p.terminate()
      p.join()
    # Remove each file independently so one failure does not skip the other.
    for stale_path in (address_file, error_file):
      try:
        os.remove(stale_path)
      except OSError:
        pass

  atexit.register(cleanup_process)

  start_time = time.time()
  last_log = start_time
  while not os.path.exists(address_file):
    # Sample liveness before looking for the error report: the child renames
    # the report into place before it exits, so a child observed dead here
    # has already published any report it was going to write.
    child_alive = p.is_alive()

    if os.path.exists(error_file):
      with open(error_file, 'r') as f:
        error_msg = f.read()
      try:
        os.remove(error_file)
      except OSError:
        pass
      if p.is_alive():
        p.terminate()
      # Reap the failed child right away instead of leaving it as a zombie.
      p.join()
      raise RuntimeError(f"Shared Server Process crashed:\n{error_msg}")

    if not child_alive:
      raise RuntimeError(
          "Shared Server Process died unexpectedly"
          f" with exit code {p.exitcode}")

    # Emit a progress message at most once every five minutes.
    if time.time() - last_log > 300:
      logging.warning(
          "Still waiting for %s to initialize... (%ss elapsed)",
          self._tag,
          int(time.time() - start_time))
      last_log = time.time()

    time.sleep(0.05)

  logging.info('External process successfully started for %s', self._tag)


def _create_server_in_current_process(self, address_file):
  """Hosts the shared object on a daemon thread of the current process."""
  # We need to be able to authenticate with both the manager
  # and the process.
  self._serving_manager = _SingletonRegistrar(
      address=('localhost', 0), authkey=AUTH_KEY)
  multiprocessing.current_process().authkey = AUTH_KEY
  # Initialize eagerly to avoid acting as the server if there are issues.
  # Note, however, that _create_server itself is called lazily.
  _process_level_singleton_manager.register_singleton(
      self._constructor, self._tag, initialize_eagerly=True)
  self._server = self._serving_manager.get_server()
  logging.info(
      'Starting proxy server at %s for shared %s',
      self._server.address,
      self._tag)
  # Write to a temporary name and rename, so readers never observe a
  # partially written address file.
  with open(address_file + '.tmp', 'w') as fout:
    fout.write('%s:%d' % self._server.address)
  os.rename(address_file + '.tmp', address_file)
  t = threading.Thread(target=self._server.serve_forever, daemon=True)
  t.start()
  logging.info('Done starting server')
Loading
Loading