From 41687c25963c19725721b240ee5de902c5294c03 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lo=C3=AFc=20Pottier?= <48072795+lpottier@users.noreply.github.com> Date: Thu, 4 Apr 2024 09:03:51 -0700 Subject: [PATCH] Add AMSMonitor interface and unify both RMQ API (#32) (#62) Signed-off-by: Loic Pottier --- pyproject.toml | 11 +- src/AMSWorkflow/ams/monitor.py | 312 ++++++++++++++++ src/AMSWorkflow/ams/rmq.py | 529 ++++++++++++++++++++++++++- src/AMSWorkflow/ams/rmq_async.py | 401 -------------------- src/AMSWorkflow/ams/stage.py | 360 ++++++++---------- src/AMSWorkflow/ams_wf/AMSBroker.py | 4 +- src/AMSWorkflow/ams_wf/AMSDBStage.py | 10 +- 7 files changed, 1006 insertions(+), 621 deletions(-) create mode 100644 src/AMSWorkflow/ams/monitor.py delete mode 100644 src/AMSWorkflow/ams/rmq_async.py diff --git a/pyproject.toml b/pyproject.toml index 23502dd3..3bea2d3f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -77,13 +77,18 @@ exclude = [ # E226: Missing white space around arithmetic operator [tool.ruff] -ignore = ["E501", "W503", "E226", "BLK100", "E203"] +lint.ignore = ["E501", "E226", "E203"] show-fixes = true - +exclude = [ + ".git", + "__pycache__", + "*.egg-info", + "build" +] # change the default line length number or characters. line-length = 120 +lint.select = ['E', 'F', 'W', 'A', 'PLC', 'PLE', 'PLW', 'I', 'N', 'Q'] [tool.yapf] ignore = ["E501", "W503", "E226", "BLK100", "E203"] column_limit = 120 - diff --git a/src/AMSWorkflow/ams/monitor.py b/src/AMSWorkflow/ams/monitor.py new file mode 100644 index 00000000..903ec3e1 --- /dev/null +++ b/src/AMSWorkflow/ams/monitor.py @@ -0,0 +1,312 @@ +#!/usr/bin/env python3 +# Copyright 2021-2023 Lawrence Livermore National Security, LLC and other +# AMSLib Project Developers +# +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import datetime +import json +import logging +import multiprocessing +import threading +import time +from typing import Callable, List, Union + + +class AMSMonitor: + """ + AMSMonitor can be used to decorate class methods and will + record automatically the duration of the tasks in a hashmap + with timestamp. The decorator will also automatically + record the values of all attributes of the class. + + class ExampleTask1(Task): + def __init__(self): + self.total_bytes = 0 + self.total_bytes2 = 0 + + # @AMSMonitor() would record all attributes + # (total_bytes and total_bytes2) and the duration + # of the block under the name amsmonitor_duration. + # Each time the same block (function in class or + # predefined tag) is being monitored, AMSMonitor + # create a new record with a timestamp (see below). + # + # @AMSMonitor(accumulate=True) records also all + # attributes but does not create a new record each + # time that block is being monitored, the first + # timestamp is always being used and only + # amsmonitor_duration is being accumulated. + # The user-managed attributes (like total_bytes + # and total_bytes2 ) are not being accumulated. + # By default, accumulate=False. + + # Example: we do not want to record total_bytes + # but just total_bytes2 + @AMSMonitor(record=["total_bytes2"]) + def __call__(self): + i = 0 + # Here we have to manually provide the current object being monitored + with AMSMonitor(obj=self, tag="while_loop"): + while (i<=3): + self.total_bytes += 10 + self.total_bytes2 = 1 + i += 1 + + Each time `ExampleTask1()` is being called, AMSMonitor will + populate `_stats` as follows (showed with two calls here): + { + "ExampleTask1": { + "while_loop": { + "02/29/2024-19:27:53": { + "total_bytes2": 30, + "amsmonitor_duration": 4.004607439041138 + } + }, + "__call__": { + "02/29/2024-19:29:24": { + "total_bytes2": 30, + "amsmonitor_duration": 4.10461138 + } + } + } + } + + Attributes: + record: attributes to record, if None, all attributes + will be recorded, except objects (e.g., multiprocessing.Queue) + which can cause problem. if empty ([]), no attributes will + be recorded, only amsmonitor_duration will be recorded. + accumulate: If True, AMSMonitor will accumulate recorded + data instead of recording a new timestamp for + any subsequent call of AMSMonitor on the same method. + We accumulate only records managed by AMSMonitor, like + amsmonitor_duration. We do not accumulate records + from the monitored class/function. + obj: Mandatory if using `with` statement, `object` is + the main object should be provided (i.e., self). + tag: Mandatory if using `with` statement, `tag` is the + name that will appear in the record for that + context manager statement. + """ + + _manager = multiprocessing.Manager() + _stats = _manager.dict() + _ts_format = "%m/%d/%Y-%H:%M:%S" + _reserved_keys = ["amsmonitor_duration"] + _lock = threading.Lock() + _count = 0 + + def __init__(self, record=None, accumulate=False, obj=None, tag=None, logger: logging.Logger = None, **kwargs): + self.accumulate = accumulate + self.kwargs = kwargs + self.record = record + if not isinstance(record, list): + self.record = None + # We make sure we do not overwrite protected attributes managed by AMSMonitor + if self.record: + self.record = self._remove_reserved_keys(self.record) + self.object = obj + self.start_time = 0 + self.internal_ts = 0 + self.tag = tag + AMSMonitor._count += 1 + self.logger = logger if logger else logging.getLogger(__name__) + + def __str__(self) -> str: + return AMSMonitor.info() if AMSMonitor._stats != {} else "{}" + + def __repr__(self) -> str: + return self.__str__() + + def lock(self): + AMSMonitor._lock.acquire() + + def unlock(self): + AMSMonitor._lock.release() + + def __enter__(self): + if not self.object or not self.tag: + self.logger.error('missing parameter "object" or "tag" when using context manager syntax') + return + self.start_monitor() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.stop_monitor() + + @classmethod + def info(cls) -> str: + s = "" + if cls._stats == {}: + return "{}" + for k, v in cls._stats.items(): + s += f"{k}\n" + for i, j in v.items(): + s += f" {i}\n" + for p, z in j.items(): + s += f" {p:<10}\n" + for r, q in z.items(): + s += f" {r:<30} => {q}\n" + return s.rstrip() + + @classmethod + @property + def stats(cls): + return AMSMonitor._stats + + @classmethod + @property + def format_ts(cls): + return AMSMonitor._ts_format + + @classmethod + def convert_ts(cls, ts: str) -> datetime.datetime: + return datetime.strptime(ts, cls.format_ts) + + @classmethod + def json(cls, json_output: str): + """ + Write the collected metrics to a JSON file. + """ + with open(json_output, "w") as fp: + # we have to use .copy() as DictProxy is not serializable + json.dump(cls._stats.copy(), fp, indent=4) + # To avoid partial line at the end of the file + fp.write("\n") + + def start_monitor(self, *args, **kwargs): + self.start_time = time.time() + self.internal_ts = datetime.datetime.now().strftime(self._ts_format) + + def stop_monitor(self): + end = time.time() + class_name = self.object.__class__.__name__ + func_name = self.tag + + new_data = vars(self.object) + # Filter out multiprocessing which cannot be stored without causing RuntimeError + new_data = self._filter_out_object(new_data) + # We remove stuff we do not want (attribute of the calling class captured by vars()) + if self.record != []: + new_data = self._filter(new_data, self.record) + # We inject some data we want to record + new_data["amsmonitor_duration"] = end - self.start_time + self._update_db(new_data, class_name, func_name, self.internal_ts) + + # We reinitialize some variables + self.start_time = 0 + self.internal_ts = 0 + + def __call__(self, func: Callable): + """ + The main decorator. + """ + + def wrapper(*args, **kwargs): + ts = datetime.datetime.now().strftime(self._ts_format) + start = time.time() + value = func(*args, **kwargs) + end = time.time() + if not hasattr(args[0], "__dict__"): + return value + class_name = args[0].__class__.__name__ + func_name = self.tag if self.tag else func.__name__ + new_data = vars(args[0]) + + # Filter out multiprocessing which cannot be stored without causing RuntimeError + new_data = self._filter_out_object(new_data) + + # We remove stuff we do not want (attribute of the calling class captured by vars()) + new_data = self._filter(new_data, self.record) + new_data["amsmonitor_duration"] = end - start + self._update_db(new_data, class_name, func_name, ts) + return value + + return wrapper + + def _update_db(self, new_data: dict, class_name: str, func_name: str, ts: str): + """ + This function update the hashmap containing all the records. + """ + self.lock() + if class_name not in AMSMonitor._stats: + AMSMonitor._stats[class_name] = {} + + if func_name not in AMSMonitor._stats[class_name]: + temp = AMSMonitor._stats[class_name] + temp.update({func_name: {}}) + AMSMonitor._stats[class_name] = temp + temp = AMSMonitor._stats[class_name] + + # We accumulate for each class with a different name + if self.accumulate and temp[func_name] != {}: + ts = self._get_ts(class_name, func_name) + temp[func_name][ts] = self._acc(temp[func_name][ts], new_data) + else: + temp[func_name][ts] = {} + for k, v in new_data.items(): + temp[func_name][ts][k] = v + # This trick is needed because AMSMonitor._stats is a manager.dict (not shared memory) + AMSMonitor._stats[class_name] = temp + self.unlock() + + def _remove_reserved_keys(self, d: Union[dict, List]) -> dict: + for key in self._reserved_keys: + if key in d: + self.logger.warning(f"attribute {key} is protected and will be ignored ({d})") + if isinstance(d, list): + idx = d.index(key) + d.pop(idx) + elif isinstance(d, dict): + del d[key] + return d + + def _acc(self, original: dict, new_data: dict) -> dict: + """ + Sum up element-wise two hashmaps (ignore fields that are not common) + """ + for k, v in new_data.items(): + # We accumalate variable internally managed by AMSMonitor (duration etc) + if k in AMSMonitor._reserved_keys: + original[k] = float(original[k]) + float(v) + else: + original[k] = v + return original + + def _filter_out_object(self, data: dict) -> dict: + """ + Filter out a hashmap to remove objects which can cause errors + """ + + def is_serializable(x): + try: + json.dumps(x) + return True + except (TypeError, OverflowError): + return False + + new_dict = {k: v for k, v in data.items() if is_serializable(v)} + + return new_dict + + def _filter(self, data: dict, keys: List[str]) -> dict: + """ + Filter out a hashmap to contains only keys listed by list of keys + """ + if not self.record: + return data + return {k: v for k, v in data.items() if k in keys} + + def _get_ts(self, class_name: str, tag: str) -> str: + """ + Return initial timestamp for a given monitored function. + """ + ts = datetime.datetime.now().strftime(self._ts_format) + if class_name not in AMSMonitor._stats or tag not in AMSMonitor._stats[class_name]: + return ts + + init_ts = list(AMSMonitor._stats[class_name][tag].keys()) + if len(init_ts) > 1: + self.logger.warning(f"more than 1 timestamp detected for {class_name} / {tag}") + return ts if init_ts == [] else init_ts[0] diff --git a/src/AMSWorkflow/ams/rmq.py b/src/AMSWorkflow/ams/rmq.py index f6c7b3e6..3461b8d3 100644 --- a/src/AMSWorkflow/ams/rmq.py +++ b/src/AMSWorkflow/ams/rmq.py @@ -3,21 +3,156 @@ # # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +import functools import logging import ssl +import struct import traceback +from typing import Callable, Tuple +import numpy as np import pika -class RMQChannel: +class AMSMessage(object): """ - A wrapper around RMQ channel + Represents a RabbitMQ incoming message from AMSLib. + + Attributes: + body: The body of the message as received from RabbitMQ + """ + + def __init__(self, body: str): + self.body = body + + def header_format(self) -> str: + """ + This string represents the AMS format in Python pack format: + See https://docs.python.org/3/library/struct.html#format-characters + - 1 byte is the size of the header (here 12). Limit max: 255 + - 1 byte is the precision (4 for float, 8 for double). Limit max: 255 + - 2 bytes are the MPI rank (0 if AMS is not running with MPI). Limit max: 65535 + - 4 bytes are the number of elements in the message. Limit max: 2^32 - 1 + - 2 bytes are the input dimension. Limit max: 65535 + - 2 bytes are the output dimension. Limit max: 65535 + - 4 bytes are for aligning memory to 8 + + |__Header_size__|__Datatype__|__Rank__|__#elem__|__InDim__|__OutDim__|...real data...| + + Then the data starts at 12 and is structered as pairs of input/outputs. + Let K be the total number of elements, then we have K pairs of inputs/outputs (either float or double): + + |__Header_(12B)__|__Input 1__|__Output 1__|...|__Input_K__|__Output_K__| + + """ + return "BBHIHHI" + + def endianness(self) -> str: + """ + '=' means native endianness in standart size (system). + See https://docs.python.org/3/library/struct.html#format-characters + """ + return "=" + + def encode(self, num_elem: int, input_dim: int, output_dim: int, dtype_byte: int = 4) -> bytes: + """ + For debugging and testing purposes, this function encode a message identical to what AMS would send + """ + header_format = self.endianness() + self.header_format() + hsize = struct.calcsize(header_format) + assert dtype_byte in [4, 8] + dt = "f" if dtype_byte == 4 else "d" + mpi_rank = 0 + data = np.random.rand(num_elem * (input_dim + output_dim)) + header_content = (hsize, dtype_byte, mpi_rank, data.size, input_dim, output_dim) + # float or double + msg_format = f"{header_format}{data.size}{dt}" + return struct.pack(msg_format, *header_content, *data) + + def _parse_header(self, body: str) -> dict: + """ + Parse the header to extract information about data. + """ + fmt = self.endianness() + self.header_format() + if len(body) == 0: + print("Empty message. skipping") + return {} + + hsize = struct.calcsize(fmt) + res = {} + # Parse header + ( + res["hsize"], + res["datatype"], + res["mpirank"], + res["num_element"], + res["input_dim"], + res["output_dim"], + res["padding"], + ) = struct.unpack(fmt, body[:hsize]) + assert hsize == res["hsize"] + assert res["datatype"] in [4, 8] + if len(body) < hsize: + print(f"Incomplete message of size {len(body)}. Header should be of size {hsize}. skipping") + return {} + + # Theoritical size in Bytes for the incoming message (without the header) + # Int() is needed otherwise we might overflow here (because of uint16 / uint8) + res["dsize"] = int(res["datatype"]) * int(res["num_element"]) * (int(res["input_dim"]) + int(res["output_dim"])) + res["msg_size"] = hsize + res["dsize"] + res["multiple_msg"] = len(body) != res["msg_size"] + return res + + def _parse_data(self, body: str, header_info: dict) -> np.array: + data = np.array([]) + if len(body) == 0: + return data + hsize = header_info["hsize"] + dsize = header_info["dsize"] + try: + if header_info["datatype"] == 4: # if datatype takes 4 bytes (float) + data = np.frombuffer(body[hsize : hsize + dsize], dtype=np.float32) + else: + data = np.frombuffer(body[hsize : hsize + dsize], dtype=np.float64) + except ValueError as e: + print(f"Error: {e} => {header_info}") + return np.array([]) + + idim = header_info["input_dim"] + odim = header_info["output_dim"] + data = data.reshape((-1, idim + odim)) + # Return input, output + return data[:, :idim], data[:, idim:] + + def _decode(self, body: str) -> Tuple[np.array]: + input = [] + output = [] + # Multiple AMS messages could be packed in one RMQ message + while body: + header_info = self._parse_header(body) + temp_input, temp_output = self._parse_data(body, header_info) + print(f"input shape {temp_input.shape} outpute shape {temp_output.shape}") + # total size of byte we read for that message + chunk_size = header_info["hsize"] + header_info["dsize"] + input.append(temp_input) + output.append(temp_output) + # We remove the current message and keep going + body = body[chunk_size:] + return np.concatenate(input), np.concatenate(output) + + def decode(self) -> Tuple[np.array]: + return self._decode(self.body) + + +class AMSChannel: + """ + A wrapper around Pika RabbitMQ channel """ - def __init__(self, connection, q_name): + def __init__(self, connection, q_name, logger: logging.Logger = None): self.connection = connection self.q_name = q_name + self.logger = logger if logger else logging.getLogger(__name__) def __enter__(self): self.open() @@ -61,7 +196,7 @@ def receive(self, n_msg: int = None, accum_msg=list()): # Call the call on the message parts try: accum_msg.append( - RMQClient.callback( + BlockingClient.callback( method_frame, properties, body, @@ -101,9 +236,9 @@ def purge(self): self.channel.queue_purge(self.q_name) -class RMQClient: +class BlockingClient: """ - RMQClient is a class that manages the RMQ client lifecycle. + BlockingClient is a class that manages a simple blocking RMQ client lifecycle. """ def __init__(self, host, port, vhost, user, password, cert, logger: logging.Logger = None): @@ -111,8 +246,9 @@ def __init__(self, host, port, vhost, user, password, cert, logger: logging.Logg # openssl s_client -connect $REMOTE_HOST:$REMOTE_PORT -showcerts < /dev/null 2>/dev/null | sed -ne '/-BEGIN CERTIFICATE-/,/-END CERTIFICATE-/p' rmq-pds.crt self.logger = logger if logger else logging.getLogger(__name__) self.cert = cert - self.context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2) + self.context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) self.context.verify_mode = ssl.CERT_REQUIRED + self.context.check_hostname = False self.context.load_verify_locations(self.cert) self.host = host self.vhost = vhost @@ -139,4 +275,381 @@ def __exit__(self, exc_type, exc_val, exc_tb): def connect(self, queue): """Connect to the queue""" - return RMQChannel(self.connection, queue) + return AMSChannel(self.connection, queue) + + +class AsyncConsumer(object): + """ + Asynchronous RMQ consumer. AsyncConsumer handles unexpected interactions + with RabbitMQ such as channel and connection closures. AsyncConsumer can + receive messages but cannot send messages. + """ + + def __init__( + self, + host: str, + port: str, + vhost: str, + user: str, + password: str, + cert: str, + queue: str, + prefetch_count: int = 1, + on_message_cb: Callable = None, + on_close_cb: Callable = None, + logger: logging.Logger = None, + ): + """Create a new instance of the consumer class, passing in the AMQP + URL used to connect to RabbitMQ. + + :param str credentials: The credentials file in JSON + :param str cacert: The TLS certificate + :param str queue: The queue to listen to + :param Callable: on_message_cb this function will be called each time Pika receive a message + :param Callable: on_message_cb this function will be called when Pika will close the connection + :param int: prefetch_count Define consumer throughput, should be relative to resource and number of messages expected + + """ + self._user = user + self._passwd = password + self._host = host + self._port = port + self._vhost = vhost + self._cacert = cert + self._queue = queue + + self.should_reconnect = False + # Holds the latest error/reason to reconnect + # Could be a Tuple like (200, 'Normal shutdown') or an exception from pika.AMQPError + self.reconnect_reason = None + self.was_consuming = False + self.logger = logger if logger else logging.getLogger(__name__) + + self._connection = None + self._connection_parameters = None + self._channel = None + self._closing = False + self._consumer_tag = None + self._consuming = False + self._prefetch_count = prefetch_count + self._on_message_cb = on_message_cb + self._on_close_cb = on_close_cb + + def __enter__(self): + self.run() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.stop() + + def connection_params(self): + """ + Create the pika credentials using TLS needed to connect to RabbitMQ. + + :rtype: pika.ConnectionParameters + """ + ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + ssl_context.verify_mode = ssl.CERT_REQUIRED + ssl_context.check_hostname = False + ssl_context.load_verify_locations(self._cacert) + + pika_credentials = pika.PlainCredentials(self._user, self._passwd) + return pika.ConnectionParameters( + host=self._host, + port=self._port, + virtual_host=self._vhost, + credentials=pika_credentials, + ssl_options=pika.SSLOptions(ssl_context), + ) + + def connect(self): + """This method connects to RabbitMQ, returning the connection handle. + When the connection is established, the on_connection_open method + will be invoked by pika. + + :rtype: pika.SelectConnection + + """ + self._connection_parameters = self.connection_params() + self.logger.debug(f"Connecting to {self._host}") + + return pika.SelectConnection( + parameters=self._connection_parameters, + on_open_callback=self.on_connection_open, + on_open_error_callback=self.on_connection_open_error, + on_close_callback=self.on_connection_closed, + ) + + def close_connection(self): + self._consuming = False + if self._connection.is_closing or self._connection.is_closed: + self.logger.debug("Connection is closing or already closed") + else: + self.logger.debug("Closing connection") + self._connection.close() + + def on_connection_open(self, connection): + """This method is called by pika once the connection to RabbitMQ has + been established. It passes the handle to the connection object in + case we need it, but in this case, we'll just mark it unused. + + :param pika.SelectConnection _unused_connection: The connection + + """ + assert self._connection is connection + self.logger.debug("Connection opened") + self.open_channel() + + def on_connection_open_error(self, _unused_connection, err): + """This method is called by pika if the connection to RabbitMQ + can't be established. + + :param pika.SelectConnection _unused_connection: The connection + :param Exception err: The error + + """ + self.logger.error(f"Connection open failed: {err}") + self.reconnect_reason = err + self.reconnect() + + def on_connection_closed(self, _unused_connection, reason): + """This method is invoked by pika when the connection to RabbitMQ is + closed unexpectedly. Since it is unexpected, we will reconnect to + RabbitMQ if it disconnects. + + :param pika.connection.Connection connection: The closed connection obj + :param Exception reason: exception representing reason for loss of + connection. + + """ + self._channel = None + if self._closing: + self._connection.ioloop.stop() + else: + self.logger.warning(f"Connection closed, reconnect necessary: {reason}") + self.reconnect_reason = reason + self.reconnect() + + def reconnect(self): + """Will be invoked if the connection can't be opened or is + closed. Indicates that a reconnect is necessary then stops the + ioloop. + + """ + self.should_reconnect = True + self.stop() + + def open_channel(self): + """Open a new channel with RabbitMQ by issuing the Channel.Open RPC + command. When RabbitMQ responds that the channel is open, the + on_channel_open callback will be invoked by pika. + + """ + self.logger.debug("Creating a new channel") + self._connection.channel(on_open_callback=self.on_channel_open) + + def on_channel_open(self, channel): + """This method is invoked by pika when the channel has been opened. + The channel object is passed in so we can make use of it. + + Since the channel is now open, we'll declare the exchange to use. + + :param pika.channel.Channel channel: The channel object + + """ + self._channel = channel + self.logger.debug("Channel opened") + self.add_on_channel_close_callback() + # we do not set up exchange first here, we use the default exchange '' + self.setup_queue(self._queue) + + def add_on_channel_close_callback(self): + """This method tells pika to call the on_channel_closed method if + RabbitMQ unexpectedly closes the channel. + + """ + self.logger.debug("Adding channel close callback") + self._channel.add_on_close_callback(self.on_channel_closed) + + def on_channel_closed(self, channel, reason): + """Invoked by pika when RabbitMQ unexpectedly closes the channel. + Channels are usually closed if you attempt to do something that + violates the protocol, such as re-declare an exchange or queue with + different parameters. In this case, we'll close the connection + to shutdown the object. + + :param pika.channel.Channel: The closed channel + :param Exception reason: why the channel was closed + + """ + self.logger.debug(f"Channel was closed. {reason}") + if isinstance(self._on_close_cb, Callable): + self._on_close_cb() # running user callback + self.close_connection() + + def setup_queue(self, queue_name): + """Setup the queue on RabbitMQ by invoking the Queue.Declare RPC + command. When it is complete, the on_queue_declareok method will + be invoked by pika. + + :param str|unicode queue_name: The name of the queue to declare. + + """ + self.logger.debug(f'Declaring queue "{queue_name}"') + cb = functools.partial(self.on_queue_declareok, userdata=queue_name) + self._channel.queue_declare(queue=queue_name, exclusive=False, callback=cb) + + def on_queue_declareok(self, _unused_frame, userdata): + """Method invoked by pika when the Queue.Declare RPC call made in + setup_queue has completed. In this method we will bind the queue + and exchange together with the routing key by issuing the Queue.Bind + RPC command. When this command is complete, the on_bindok method will + be invoked by pika. + + :param pika.frame.Method _unused_frame: The Queue.DeclareOk frame + :param str|unicode userdata: Extra user data (queue name) + + """ + queue_name = userdata + self.logger.debug(f'Queue "{queue_name}" declared') + self.set_qos() + + def set_qos(self): + """This method sets up the consumer prefetch to only be delivered + one message at a time. The consumer must acknowledge this message + before RabbitMQ will deliver another one. You should experiment + with different prefetch values to achieve desired performance. + + """ + self._channel.basic_qos(prefetch_count=self._prefetch_count, callback=self.on_basic_qos_ok) + + def on_basic_qos_ok(self, _unused_frame): + """Invoked by pika when the Basic.QoS method has completed. At this + point we will start consuming messages by calling start_consuming + which will invoke the needed RPC commands to start the process. + + :param pika.frame.Method _unused_frame: The Basic.QosOk response frame + + """ + self.logger.debug(f"QOS set to: {self._prefetch_count}") + self.start_consuming() + + def start_consuming(self): + """This method sets up the consumer by first calling + add_on_cancel_callback so that the object is notified if RabbitMQ + cancels the consumer. It then issues the Basic.Consume RPC command + which returns the consumer tag that is used to uniquely identify the + consumer with RabbitMQ. We keep the value to use it when we want to + cancel consuming. The on_message method is passed in as a callback pika + will invoke when a message is fully received. + + """ + self.logger.debug("Issuing consumer related RPC commands") + self.add_on_cancel_callback() + self._consumer_tag = self._channel.basic_consume(self._queue, self.on_message, auto_ack=False) + self.was_consuming = True + self._consuming = True + self.logger.info(f"Waiting for messages (tag: {self._consumer_tag}). To exit press CTRL+C") + + def add_on_cancel_callback(self): + """Add a callback that will be invoked if RabbitMQ cancels the consumer + for some reason. If RabbitMQ does cancel the consumer, + on_consumer_cancelled will be invoked by pika. + + """ + self.logger.debug("Adding consumer cancellation callback") + self._channel.add_on_cancel_callback(self.on_consumer_cancelled) + + def on_consumer_cancelled(self, method_frame): + """Invoked by pika when RabbitMQ sends a Basic.Cancel for a consumer + receiving messages. + + :param pika.frame.Method method_frame: The Basic.Cancel frame + + """ + self.logger.debug(f"Consumer was cancelled remotely, shutting down: {method_frame}") + if self._channel: + self._channel.close() + + def on_message(self, _unused_channel, basic_deliver, properties, body): + """Invoked by pika when a message is delivered from RabbitMQ. The + channel is passed for your convenience. The basic_deliver object that + is passed in carries the exchange, routing key, delivery tag and + a redelivered flag for the message. The properties passed in is an + instance of BasicProperties with the message properties and the body + is the message that was sent. + + :param pika.channel.Channel _unused_channel: The channel object + :param pika.Spec.Basic.Deliver: basic_deliver method + :param pika.Spec.BasicProperties: properties + :param bytes body: The message body + + """ + self.logger.info(f"Received message #{basic_deliver.delivery_tag} from {properties}") + if isinstance(self._on_message_cb, Callable): + self._on_message_cb(_unused_channel, basic_deliver, properties, body) + self.acknowledge_message(basic_deliver.delivery_tag) + + def acknowledge_message(self, delivery_tag): + """Acknowledge the message delivery from RabbitMQ by sending a + Basic.Ack RPC method for the delivery tag. + + :param int delivery_tag: The delivery tag from the Basic.Deliver frame + + """ + self.logger.debug(f"Acknowledging message {delivery_tag}") + self._channel.basic_ack(delivery_tag) + + def stop_consuming(self): + """Tell RabbitMQ that you would like to stop consuming by sending the + Basic.Cancel RPC command. + + """ + if self._channel: + self.logger.debug("Sending a Basic.Cancel RPC command to RabbitMQ") + cb = functools.partial(self.on_cancelok, userdata=self._consumer_tag) + self._channel.basic_cancel(self._consumer_tag, cb) + + def on_cancelok(self, _unused_frame, userdata): + """This method is invoked by pika when RabbitMQ acknowledges the + cancellation of a consumer. At this point we will close the channel. + This will invoke the on_channel_closed method once the channel has been + closed, which will in-turn close the connection. + + :param pika.frame.Method _unused_frame: The Basic.CancelOk frame + :param str|unicode userdata: Extra user data (consumer tag) + + """ + self._consuming = False + self.logger.debug(f"RabbitMQ acknowledged the cancellation of the consumer: {userdata}") + self.close_channel() + + def close_channel(self): + """Call to close the channel with RabbitMQ cleanly by issuing the + Channel.Close RPC command. + """ + self.logger.debug("Closing the channel") + self._channel.close() + + def run(self): + """Run the example consumer by connecting to RabbitMQ and then + starting the IOLoop to block and allow the SelectConnection to operate. + """ + self._connection = self.connect() + self._connection.ioloop.start() + + def stop(self): + """Cleanly shutdown the connection to RabbitMQ by stopping the consumer + with RabbitMQ. When RabbitMQ confirms the cancellation, on_cancelok + will be invoked by pika, which will then closing the channel and + connection. + """ + if not self._closing: + self._closing = True + self.logger.debug(" Stopping RabbitMQ connection") + if self._consuming: + self.stop_consuming() + else: + if self._connection: + self._connection.ioloop.stop() + self.logger.debug("Stopped RabbitMQ connection") diff --git a/src/AMSWorkflow/ams/rmq_async.py b/src/AMSWorkflow/ams/rmq_async.py deleted file mode 100644 index 65209520..00000000 --- a/src/AMSWorkflow/ams/rmq_async.py +++ /dev/null @@ -1,401 +0,0 @@ -# Copyright 2021-2023 Lawrence Livermore National Security, LLC and other -# AMSLib Project Developers -# -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -import ssl -import sys -import os -import json -import logging -import re -import copy -import functools -import pika -from pika.exchange_type import ExchangeType -from typing import Callable -import numpy as np - -class RMQConsumer(object): - """ - Asynchronous RMQ consumer. - RMQConsumer handles unexpected interactions - with RabbitMQ such as channel and connection closures. - """ - - def __init__(self, - credentials: str, - cacert: str, - queue: str, - on_message_cb: Callable = None, - on_close_cb: Callable = None, - prefetch_count: int = 1): - """Create a new instance of the consumer class, passing in the AMQP - URL used to connect to RabbitMQ. - - :param str credentials: The credentials file in JSON - :param str cacert: The TLS certificate - :param str queue: The queue to listen to - :param Callable: on_message_cb this function will be called each time Pika receive a message - :param Callable: on_message_cb this function will be called when Pika will close the connection - :param int: prefetch_count Define consumer throughput, should be relative to resource and number of messages expected - - """ - self.should_reconnect = False - # Holds the latest error/reason to reconnect - # Could be a Tuple like (200, 'Normal shutdown') or an exception from pika.AMQPError - self.reconnect_reason = None - self.was_consuming = False - - self._connection = None - self._connection_parameters = None - self._channel = None - self._closing = False - self._consumer_tag = None - self._consuming = False - self._prefetch_count = prefetch_count - self._on_message_cb = on_message_cb - self._on_close_cb = on_close_cb - - self._credentials = self._parse_credentials(credentials) - self._cacert = cacert - self._queue = queue - - def __enter__(self): - self.run() - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.stop() - - def _parse_credentials(self, json_file: str) -> dict: - """ Internal method to parse the credentials file""" - data = {} - with open(json_file, 'r') as f: - data = json.load(f) - return data - - def create_credentials(self): - """ - Create the pika credentials using TLS needed to connect to RabbitMQ. - - :rtype: pika.ConnectionParameters - - """ - ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2) - ssl_context.verify_mode = ssl.CERT_REQUIRED - ssl_context.load_verify_locations(self._cacert) - - pika_credentials = pika.PlainCredentials(self._credentials["rabbitmq-user"], self._credentials["rabbitmq-password"]) - return pika.ConnectionParameters( - host=self._credentials["service-host"], - port=self._credentials["service-port"], - virtual_host=self._credentials["rabbitmq-vhost"], - credentials=pika_credentials, - ssl_options=pika.SSLOptions(ssl_context) - ) - - def connect(self): - """This method connects to RabbitMQ, returning the connection handle. - When the connection is established, the on_connection_open method - will be invoked by pika. - - :rtype: pika.SelectConnection - - """ - self._connection_parameters = self.create_credentials() - print(f"Connecting to {self._credentials['service-host']}") - - return pika.SelectConnection( - parameters = self._connection_parameters, - on_open_callback = self.on_connection_open, - on_open_error_callback = self.on_connection_open_error, - on_close_callback = self.on_connection_closed) - - def close_connection(self): - self._consuming = False - if self._connection.is_closing or self._connection.is_closed: - print("Connection is closing or already closed") - else: - print("Closing connection") - self._connection.close() - - def on_connection_open(self, _unused_connection): - """This method is called by pika once the connection to RabbitMQ has - been established. It passes the handle to the connection object in - case we need it, but in this case, we'll just mark it unused. - - :param pika.SelectConnection _unused_connection: The connection - - """ - print("Connection opened") - self.open_channel() - - def on_connection_open_error(self, _unused_connection, err): - """This method is called by pika if the connection to RabbitMQ - can't be established. - - :param pika.SelectConnection _unused_connection: The connection - :param Exception err: The error - - """ - print(f"Error: Connection open failed: {err}") - self.reconnect_reason = err - self.reconnect() - - def on_connection_closed(self, _unused_connection, reason): - """This method is invoked by pika when the connection to RabbitMQ is - closed unexpectedly. Since it is unexpected, we will reconnect to - RabbitMQ if it disconnects. - - :param pika.connection.Connection connection: The closed connection obj - :param Exception reason: exception representing reason for loss of - connection. - - """ - self._channel = None - if self._closing: - self._connection.ioloop.stop() - else: - print(f"warning: Connection closed, reconnect necessary: {reason}") - self.reconnect_reason = reason - self.reconnect() - - def reconnect(self): - """Will be invoked if the connection can't be opened or is - closed. Indicates that a reconnect is necessary then stops the - ioloop. - - """ - self.should_reconnect = True - self.stop() - - def open_channel(self): - """Open a new channel with RabbitMQ by issuing the Channel.Open RPC - command. When RabbitMQ responds that the channel is open, the - on_channel_open callback will be invoked by pika. - - """ - print("Creating a new channel") - self._connection.channel(on_open_callback = self.on_channel_open) - - def on_channel_open(self, channel): - """This method is invoked by pika when the channel has been opened. - The channel object is passed in so we can make use of it. - - Since the channel is now open, we'll declare the exchange to use. - - :param pika.channel.Channel channel: The channel object - - """ - self._channel = channel - print(f"Channel opened {self._channel}") - self.add_on_channel_close_callback() - # we do not set up exchange first here, we use the default exchange '' - self.setup_queue(self._queue) - - def add_on_channel_close_callback(self): - """This method tells pika to call the on_channel_closed method if - RabbitMQ unexpectedly closes the channel. - - """ - print("Adding channel close callback") - self._channel.add_on_close_callback(self.on_channel_closed) - - def on_channel_closed(self, channel, reason): - """Invoked by pika when RabbitMQ unexpectedly closes the channel. - Channels are usually closed if you attempt to do something that - violates the protocol, such as re-declare an exchange or queue with - different parameters. In this case, we'll close the connection - to shutdown the object. - - :param pika.channel.Channel: The closed channel - :param Exception reason: why the channel was closed - - """ - print(f"warning: Channel {channel} was closed: {reason}") - if isinstance(self._on_close_cb, Callable): - self._on_close_cb() # running user callback - self.close_connection() - - def setup_queue(self, queue_name): - """Setup the queue on RabbitMQ by invoking the Queue.Declare RPC - command. When it is complete, the on_queue_declareok method will - be invoked by pika. - - :param str|unicode queue_name: The name of the queue to declare. - - """ - print(f"Declaring queue \"{queue_name}\"") - cb = functools.partial(self.on_queue_declareok, userdata = queue_name) - self._channel.queue_declare(queue = queue_name, exclusive=False, callback = cb) - - def on_queue_declareok(self, _unused_frame, userdata): - """Method invoked by pika when the Queue.Declare RPC call made in - setup_queue has completed. In this method we will bind the queue - and exchange together with the routing key by issuing the Queue.Bind - RPC command. When this command is complete, the on_bindok method will - be invoked by pika. - - :param pika.frame.Method _unused_frame: The Queue.DeclareOk frame - :param str|unicode userdata: Extra user data (queue name) - - """ - queue_name = userdata - print(f"Queue \"{queue_name}\" declared") - self.set_qos() - - def set_qos(self): - """This method sets up the consumer prefetch to only be delivered - one message at a time. The consumer must acknowledge this message - before RabbitMQ will deliver another one. You should experiment - with different prefetch values to achieve desired performance. - - """ - self._channel.basic_qos( - prefetch_count = self._prefetch_count, - callback = self.on_basic_qos_ok - ) - - def on_basic_qos_ok(self, _unused_frame): - """Invoked by pika when the Basic.QoS method has completed. At this - point we will start consuming messages by calling start_consuming - which will invoke the needed RPC commands to start the process. - - :param pika.frame.Method _unused_frame: The Basic.QosOk response frame - - """ - print(f"QOS set to: {self._prefetch_count}") - self.start_consuming() - - def start_consuming(self): - """This method sets up the consumer by first calling - add_on_cancel_callback so that the object is notified if RabbitMQ - cancels the consumer. It then issues the Basic.Consume RPC command - which returns the consumer tag that is used to uniquely identify the - consumer with RabbitMQ. We keep the value to use it when we want to - cancel consuming. The on_message method is passed in as a callback pika - will invoke when a message is fully received. - - """ - print("Issuing consumer related RPC commands") - self.add_on_cancel_callback() - self._consumer_tag = self._channel.basic_consume( - self._queue, self.on_message, auto_ack=False) - self.was_consuming = True - self._consuming = True - print(" [*] Waiting for messages. To exit press CTRL+C") - - def add_on_cancel_callback(self): - """Add a callback that will be invoked if RabbitMQ cancels the consumer - for some reason. If RabbitMQ does cancel the consumer, - on_consumer_cancelled will be invoked by pika. - - """ - print("Adding consumer cancellation callback") - self._channel.add_on_cancel_callback(self.on_consumer_cancelled) - - def on_consumer_cancelled(self, method_frame): - """Invoked by pika when RabbitMQ sends a Basic.Cancel for a consumer - receiving messages. - - :param pika.frame.Method method_frame: The Basic.Cancel frame - - """ - print(f"Consumer was cancelled remotely, shutting down: {method_frame}") - if self._channel: - self._channel.close() - - def on_message(self, _unused_channel, basic_deliver, properties, body): - """Invoked by pika when a message is delivered from RabbitMQ. The - channel is passed for your convenience. The basic_deliver object that - is passed in carries the exchange, routing key, delivery tag and - a redelivered flag for the message. The properties passed in is an - instance of BasicProperties with the message properties and the body - is the message that was sent. - - :param pika.channel.Channel _unused_channel: The channel object - :param pika.Spec.Basic.Deliver: basic_deliver method - :param pika.Spec.BasicProperties: properties - :param bytes body: The message body - - """ - print(f"Received message #{basic_deliver.delivery_tag} from {properties}") - if isinstance(self._on_message_cb, Callable): - self._on_message_cb(_unused_channel, basic_deliver, properties, body) - self.acknowledge_message(basic_deliver.delivery_tag) - - def acknowledge_message(self, delivery_tag): - """Acknowledge the message delivery from RabbitMQ by sending a - Basic.Ack RPC method for the delivery tag. - - :param int delivery_tag: The delivery tag from the Basic.Deliver frame - - """ - print(f"Acknowledging message {delivery_tag}") - self._channel.basic_ack(delivery_tag) - - def stop_consuming(self): - """Tell RabbitMQ that you would like to stop consuming by sending the - Basic.Cancel RPC command. - - """ - if self._channel: - print(f"Sending a Basic.Cancel RPC command to RabbitMQ") - cb = functools.partial( - self.on_cancelok, userdata = self._consumer_tag) - self._channel.basic_cancel(self._consumer_tag, cb) - - def on_cancelok(self, _unused_frame, userdata): - """This method is invoked by pika when RabbitMQ acknowledges the - cancellation of a consumer. At this point we will close the channel. - This will invoke the on_channel_closed method once the channel has been - closed, which will in-turn close the connection. - - :param pika.frame.Method _unused_frame: The Basic.CancelOk frame - :param str|unicode userdata: Extra user data (consumer tag) - - """ - self._consuming = False - print(f"RabbitMQ acknowledged the cancellation of the consumer: {userdata}") - self.close_channel() - - def close_channel(self): - """Call to close the channel with RabbitMQ cleanly by issuing the - Channel.Close RPC command. - - """ - print("Closing the channel") - self._channel.close() - - def run(self): - """Run the example consumer by connecting to RabbitMQ and then - starting the IOLoop to block and allow the SelectConnection to operate. - - """ - self._connection = self.connect() - self._connection.ioloop.start() - - def stop(self): - """Cleanly shutdown the connection to RabbitMQ by stopping the consumer - with RabbitMQ. When RabbitMQ confirms the cancellation, on_cancelok - will be invoked by pika, which will then closing the channel and - connection. The IOLoop is started again because this method is invoked - when CTRL-C is pressed raising a KeyboardInterrupt exception. This - exception stops the IOLoop which needs to be running for pika to - communicate with RabbitMQ. All of the commands issued prior to starting - the IOLoop will be buffered but not processed. - - """ - if not self._closing: - self._closing = True - print("Stopping RabbitMQ connection") - if self._consuming: - self.stop_consuming() - self._connection.ioloop.start() - else: - if self._connection: - self._connection.ioloop.stop() - print("Stopped RabbitMQ connection") - else: - print("Already closed?") diff --git a/src/AMSWorkflow/ams/stage.py b/src/AMSWorkflow/ams/stage.py index bd12a628..f3770bf8 100644 --- a/src/AMSWorkflow/ams/stage.py +++ b/src/AMSWorkflow/ams/stage.py @@ -5,26 +5,27 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception import glob +import json +import os import shutil +import signal import time from abc import ABC, abstractclassmethod, abstractmethod from enum import Enum -from multiprocessing import Process, current_process +from multiprocessing import Process from multiprocessing import Queue as mp_queue from pathlib import Path from queue import Queue as ser_queue from threading import Thread -from typing import Callable, List, Tuple -import struct -import signal +from typing import Callable import numpy as np - from ams.config import AMSInstance from ams.faccessors import get_reader, get_writer +from ams.monitor import AMSMonitor +from ams.rmq import AMSMessage, AsyncConsumer from ams.store import AMSDataStore from ams.util import get_unique_fn -from ams.rmq_async import RMQConsumer BATCH_SIZE = 32 * 1024 * 1024 @@ -112,13 +113,13 @@ def __init__(self, i_queue, o_queue, callback): """ initializes a ForwardTask class with the queues and the callback. """ - if not isinstance(callback, Callable): raise TypeError(f"{callback} argument is not Callable") self.i_queue = i_queue self.o_queue = o_queue self.callback = callback + self.datasize = 0 def _action(self, data): """ @@ -136,6 +137,7 @@ def _action(self, data): raise TypeError(f"{self.callback.__name__} did not return numpy arrays") return inputs, outputs + @AMSMonitor(record=["datasize"]) def __call__(self): """ A busy loop reading messages from the i_queue, acting on those messages and forwarding @@ -152,6 +154,7 @@ def __call__(self): elif item.is_process(): inputs, outputs = self._action(item.data()) self.o_queue.put(QueueMessage(MessageType.Process, DataBlob(inputs, outputs))) + self.datasize += inputs.nbytes + outputs.nbytes elif item.is_new_model(): # This is not handled yet continue @@ -174,7 +177,9 @@ def __init__(self, o_queue, loader, pattern): self.o_queue = o_queue self.pattern = pattern self.loader = loader + self.datasize = 0 + @AMSMonitor(record=["datasize"]) def __call__(self): """ Busy loop of reading all files matching the pattern and creating @@ -193,142 +198,13 @@ def __call__(self): output_batches = np.array_split(output_data, num_batches) for j, (i, o) in enumerate(zip(input_batches, output_batches)): self.o_queue.put(QueueMessage(MessageType.Process, DataBlob(i, o))) + self.datasize += input_data.nbytes + output_data.nbytes self.o_queue.put(QueueMessage(MessageType.Terminate, None)) end = time.time() print(f"Spend {end - start} at {self.__class__.__name__}") -class RMQMessage(object): - """ - Represents a RabbitMQ incoming message from AMSLib. - - Attributes: - body: The body of the message as received from RabbitMQ - """ - - def __init__(self, body: str): - self.body = body - - def header_format(self) -> str: - """ - This string represents the AMS format in Python pack format: - See https://docs.python.org/3/library/struct.html#format-characters - - 1 byte is the size of the header (here 12). Limit max: 255 - - 1 byte is the precision (4 for float, 8 for double). Limit max: 255 - - 2 bytes are the MPI rank (0 if AMS is not running with MPI). Limit max: 65535 - - 4 bytes are the number of elements in the message. Limit max: 2^32 - 1 - - 2 bytes are the input dimension. Limit max: 65535 - - 2 bytes are the output dimension. Limit max: 65535 - - 4 bytes are for aligning memory to 8 - - |__Header_size__|__Datatype__|__Rank__|__#elem__|__InDim__|__OutDim__|...real data...| - - Then the data starts at 12 and is structered as pairs of input/outputs. - Let K be the total number of elements, then we have K pairs of inputs/outputs (either float or double): - - |__Header_(12B)__|__Input 1__|__Output 1__|...|__Input_K__|__Output_K__| - - """ - return "BBHIHHI" - - def endianness(self) -> str: - """ - '=' means native endianness in standart size (system). - See https://docs.python.org/3/library/struct.html#format-characters - """ - return "=" - - def encode(num_elem: int, input_dim: int, output_dim: int, dtype_byte: int = 4) -> bytes: - """ - For debugging and testing purposes, this function encode a message identical to what AMS would send - """ - header_format = self.endianness() + self.header_format() - hsize = struct.calcsize(header_format) - assert dtype_byte in [4, 8] - dt = "f" if dtype_byte == 4 else "d" - mpi_rank = 0 - data = np.random.rand(num_elem * (input_dim + output_dim)) - header_content = (hsize, dtype_byte, mpi_rank, data.size, input_dim, output_dim) - # float or double - msg_format = f"{header_format}{data.size}{dt}" - return struct.pack(msg_format, *header_content, *data) - - def _parse_header(self, body: str) -> dict: - """ - Parse the header to extract information about data. - """ - fmt = self.endianness() + self.header_format() - if len(body) == 0: - print(f"Empty message. skipping") - return {} - - hsize = struct.calcsize(fmt) - res = {} - # Parse header - ( - res["hsize"], - res["datatype"], - res["mpirank"], - res["num_element"], - res["input_dim"], - res["output_dim"], - res["padding"], - ) = struct.unpack(fmt, body[:hsize]) - assert hsize == res["hsize"] - assert res["datatype"] in [4, 8] - if len(body) < hsize: - print(f"Incomplete message of size {len(body)}. Header should be of size {hsize}. skipping") - return {} - - # Theoritical size in Bytes for the incoming message (without the header) - # Int() is needed otherwise we might overflow here (because of uint16 / uint8) - res["dsize"] = int(res["datatype"]) * int(res["num_element"]) * (int(res["input_dim"]) + int(res["output_dim"])) - res["msg_size"] = hsize + res["dsize"] - res["multiple_msg"] = len(body) != res["msg_size"] - return res - - def _parse_data(self, body: str, header_info: dict) -> np.array: - data = np.array([]) - if len(body) == 0: - return data - hsize = header_info["hsize"] - dsize = header_info["dsize"] - try: - if header_info["datatype"] == 4: # if datatype takes 4 bytes (float) - data = np.frombuffer(body[hsize : hsize + dsize], dtype=np.float32) - else: - data = np.frombuffer(body[hsize : hsize + dsize], dtype=np.float64) - except ValueError as e: - print(f"Error: {e} => {header_info}") - return np.array([]) - - idim = header_info["input_dim"] - odim = header_info["output_dim"] - data = data.reshape((-1, idim + odim)) - # Return input, output - return data[:, :idim], data[:, idim:] - - def _decode(self, body: str) -> Tuple[np.array]: - input = [] - output = [] - # Multiple AMS messages could be packed in one RMQ message - while body: - header_info = self._parse_header(body) - temp_input, temp_output = self._parse_data(body, header_info) - print(f"input shape {temp_input.shape} outpute shape {temp_output.shape}") - # total size of byte we read for that message - chunk_size = header_info["hsize"] + header_info["dsize"] - input.append(temp_input) - output.append(temp_output) - # We remove the current message and keep going - body = body[chunk_size:] - return np.concatenate(input), np.concatenate(output) - - def decode(self) -> Tuple[np.array]: - return self._decode(self.body) - - class RMQLoaderTask(Task): """ A RMQLoaderTask consumes data from RabbitMQ bundles the data of @@ -343,23 +219,45 @@ class RMQLoaderTask(Task): prefetch_count: Number of messages prefected by RMQ (impact performance) """ - def __init__(self, o_queue, credentials, cacert, rmq_queue, prefetch_count=1): + def __init__( + self, + o_queue, + host, + port, + vhost, + user, + password, + cert, + rmq_queue, + policy, + prefetch_count=1, + signals=[signal.SIGTERM, signal.SIGINT, signal.SIGUSR1], + ): self.o_queue = o_queue - self.credentials = credentials - self.cacert = cacert + self.cert = cert self.rmq_queue = rmq_queue self.prefetch_count = prefetch_count - - # Installing signal callbacks - p = current_process() - print(f"pid = {p.pid}") - signal.signal(signal.SIGTERM, self.signal_wrapper(self.__class__.__name__, p.pid)) - signal.signal(signal.SIGINT, self.signal_wrapper(self.__class__.__name__, p.pid)) + self.datasize = 0 self.total_time = 0 - - self.rmq_consumer = RMQConsumer( - credentials=self.credentials, - cacert=self.cacert, + self.signals = signals + self.orig_sig_handlers = {} + self.policy = policy + + # Signals can only be used within the main thread + if self.policy != "thread": + # We ignore SIGTERM, SIGUSR1, SIGINT by default so later + # we can override that handler only for RMQLoaderTask + for s in self.signals: + self.orig_sig_handlers[s] = signal.getsignal(s) + signal.signal(s, signal.SIG_IGN) + + self.rmq_consumer = AsyncConsumer( + host=host, + port=port, + vhost=vhost, + user=user, + password=password, + cert=self.cert, queue=self.rmq_queue, on_message_cb=self.callback_message, on_close_cb=self.callback_close, @@ -371,7 +269,6 @@ def callback_close(self): Callback that will be called when RabbitMQ will close the connection (or if a problem happened with the connection). """ - print(f"Sending Terminate to QueueMessage") self.o_queue.put(QueueMessage(MessageType.Terminate, None)) def callback_message(self, ch, basic_deliver, properties, body): @@ -380,13 +277,15 @@ def callback_message(self, ch, basic_deliver, properties, body): the connection (or if a problem happened with the connection). """ start_time = time.time() - input_data, output_data = RMQMessage(body).decode() + input_data, output_data = AMSMessage(body).decode() row_size = input_data[0, :].nbytes + output_data[0, :].nbytes rows_per_batch = int(np.ceil(BATCH_SIZE / row_size)) num_batches = int(np.ceil(input_data.shape[0] / rows_per_batch)) input_batches = np.array_split(input_data, num_batches) output_batches = np.array_split(output_data, num_batches) + self.datasize += input_data.nbytes + output_data.nbytes + for j, (i, o) in enumerate(zip(input_batches, output_batches)): self.o_queue.put(QueueMessage(MessageType.Process, DataBlob(i, o))) @@ -395,18 +294,24 @@ def callback_message(self, ch, basic_deliver, properties, body): def signal_wrapper(self, name, pid): def handler(signum, frame): print(f"Received SIGNUM={signum} for {name}[pid={pid}]: stopping process") - self.rmq_consumer.stop() - self.o_queue.put(QueueMessage(MessageType.Terminate, None)) - print(f"Spend {self.total_time} at {self.__class__.__name__}") + self.stop() return handler + def stop(self): + self.rmq_consumer.stop() + self.o_queue.put(QueueMessage(MessageType.Terminate, None)) + print(f"Spend {self.total_time} at {self.__class__.__name__}") + + @AMSMonitor(record=["datasize", "total_time"]) def __call__(self): """ - Busy loop of reading all files matching the pattern and creating - '100' batches which will be pushed on the queue. Upon reading all files - the Task pushes a 'Terminate' message to the queue and returns. + Busy loop of consuming messages from RMQ queue """ + # Installing signal callbacks only for RMQLoaderTask + if self.policy != "thread": + for s in self.signals: + signal.signal(s, self.signal_wrapper(self.__class__.__name__, os.getpid())) self.rmq_consumer.run() @@ -432,6 +337,7 @@ def __init__(self, i_queue, o_queue, writer_cls, out_dir): self.o_queue = o_queue self.suffix = writer_cls.get_file_format_suffix() + @AMSMonitor(record=["datasize"]) def __call__(self): """ A busy loop reading messages from the i_queue, writting the input,output data in a file @@ -447,22 +353,23 @@ def __call__(self): total_bytes_written = 0 with self.data_writer_cls(fn) as fd: bytes_written = 0 - while True: - # This is a blocking call - item = self.i_queue.get(block=True) - if item.is_terminate(): - is_terminate = True - elif item.is_process(): - data = item.data() - bytes_written += data.inputs.size * data.inputs.itemsize - bytes_written += data.outputs.size * data.outputs.itemsize - fd.store(data.inputs, data.outputs) - total_bytes_written += data.inputs.size * data.inputs.itemsize - total_bytes_written += data.outputs.size * data.outputs.itemsize - # FIXME: We currently decide to chunk files to 2GB - # of contents. Is this a good size? - if is_terminate or bytes_written >= 2 * 1024 * 1024 * 1024: - break + with AMSMonitor(obj=self, tag="internal_loop", accumulate=False): + while True: + # This is a blocking call + item = self.i_queue.get(block=True) + if item.is_terminate(): + is_terminate = True + elif item.is_process(): + data = item.data() + bytes_written += data.inputs.size * data.inputs.itemsize + bytes_written += data.outputs.size * data.outputs.itemsize + fd.store(data.inputs, data.outputs) + total_bytes_written += data.inputs.size * data.inputs.itemsize + total_bytes_written += data.outputs.size * data.outputs.itemsize + # FIXME: We currently decide to chunk files to 2GB + # of contents. Is this a good size? + if is_terminate or bytes_written >= 2 * 1024 * 1024 * 1024: + break self.o_queue.put(QueueMessage(MessageType.Process, fn)) if is_terminate: @@ -470,6 +377,7 @@ def __call__(self): break end = time.time() + self.datasize = total_bytes_written print(f"Spend {end - start} {total_bytes_written} at {self.__class__.__name__}") @@ -496,9 +404,12 @@ def __init__(self, i_queue, ams_config, db_path, store): self.i_queue = i_queue self.dir = Path(db_path).absolute() self._store = store + self.nb_requests = 0 + self.total_filesize = 0 if not self.dir.exists(): self.dir.mkdir(parents=True, exist_ok=True) + @AMSMonitor(record=["nb_requests"]) def __call__(self): """ A busy loop reading messages from the i_queue publishing them to the kosh store. @@ -509,18 +420,22 @@ def __call__(self): self.ams_config.db_path, self.ams_config.db_store, self.ams_config.name, False ).open() - while True: - item = self.i_queue.get(block=True) - if item.is_terminate(): - break - elif item.is_process(): - src_fn = Path(item.data()) - dest_file = self.dir / src_fn.name - if src_fn != dest_file: - shutil.move(src_fn, dest_file) - - if self._store: - db_store.add_candidates([str(dest_file)]) + with AMSMonitor(obj=self, tag="internal_loop", record=[]): + while True: + item = self.i_queue.get(block=True) + if item.is_terminate(): + break + elif item.is_process(): + with AMSMonitor(obj=self, tag="request_block", record=["nb_requests", "total_filesize"]): + self.nb_requests += 1 + src_fn = Path(item.data()) + dest_file = self.dir / src_fn.name + if src_fn != dest_file: + shutil.move(src_fn, dest_file) + if self._store: + db_store.add_candidates([str(dest_file)]) + + self.total_filesize += os.stat(src_fn).st_size end = time.time() print(f"Spend {end - start} at {self.__class__.__name__}") @@ -640,7 +555,7 @@ def _link_pipeline(self, policy): num_queues = 1 + len(self.actions) - 1 + 2 self._queues = [_qType() for i in range(num_queues)] - self._tasks = [self.get_load_task(self._queues[0])] + self._tasks = [self.get_load_task(self._queues[0], policy)] for i, a in enumerate(self.actions): self._tasks.append(ForwardTask(self._queues[i], self._queues[i + 1], a)) @@ -667,7 +582,7 @@ def execute(self, policy): self._execute_tasks(policy) @abstractmethod - def get_load_task(self, o_queue): + def get_load_task(self, o_queue, policy): """ Callback to the child class to return the task that loads data from some unspecified entry-point. """ @@ -736,7 +651,7 @@ def __init__(self, db_dir, store, dest_dir, stage_dir, db_type, src, src_type, p self._pattern = pattern self._src_type = src_type - def get_load_task(self, o_queue): + def get_load_task(self, o_queue, policy): """ Return a Task that loads data from the filesystem @@ -781,31 +696,50 @@ class RMQPipeline(Pipeline): A 'Pipeline' reading data from RabbitMQ and storing them back to the filesystem. Attributes: - credentials: The JSON credentials to connect to RMQ Server - cacert: The TLS certificate + host: RabbitMQ host + port: RabbitMQ port + vhost: RabbitMQ virtual host + user: RabbitMQ username + password: RabbitMQ password for username + cert: The TLS certificate rmq_queue: The RMQ queue to listen to. """ - def __init__(self, db_dir, store, dest_dir, stage_dir, db_type, credentials, cacert, rmq_queue): + def __init__(self, db_dir, store, dest_dir, stage_dir, db_type, host, port, vhost, user, password, cert, rmq_queue): """ Initialize a RMQPipeline that will write data to the 'dest_dir' and optionally publish these files to the kosh-store 'store' by using the stage_dir as an intermediate directory. """ super().__init__(db_dir, store, dest_dir, stage_dir, db_type) - self._credentials = Path(credentials) - self._cacert = Path(cacert) + self._host = host + self._port = port + self._vhost = vhost + self._user = user + self._password = password + self._cert = Path(cert) self._rmq_queue = rmq_queue - def get_load_task(self, o_queue): + def get_load_task(self, o_queue, policy): """ Return a Task that loads data from the filesystem Args: o_queue: The queue the load task will push read data. - Returns: An RMQLoaderTask instance reading data from the filesystem and forwarding the values to the o_queue. - """ - return RMQLoaderTask(o_queue, self._credentials, self._cacert, self._rmq_queue) + Returns: An RMQLoaderTask instance reading data from the + filesystem and forwarding the values to the o_queue. + """ + return RMQLoaderTask( + o_queue, + self._host, + self._port, + self._vhost, + self._user, + self._password, + self._cert, + self._rmq_queue, + policy, + ) @staticmethod def add_cli_args(parser): @@ -823,18 +757,38 @@ def from_cli(cls, args): """ Create RMQPipeline from the user provided CLI. """ - print("Creating database from here", args.persistent_db_path) + + # TODO: implement an interface so users can plug any parser for RMQ credentials + config = cls.parse_credentials(cls, args.creds) + host = config["service-host"] + port = config["service-port"] + vhost = config["rabbitmq-vhost"] + user = config["rabbitmq-user"] + password = config["rabbitmq-password"] + return cls( args.persistent_db_path, args.store, args.dest_dir, args.stage_dir, args.db_type, - args.creds, + host, + port, + vhost, + user, + password, args.cert, args.queue, ) + @staticmethod + def parse_credentials(self, json_file: str) -> dict: + """Internal method to parse the credentials file""" + data = {} + with open(json_file, "r") as f: + data = json.load(f) + return data + def get_pipeline(src_mechanism="fs"): """ @@ -845,8 +799,8 @@ def get_pipeline(src_mechanism="fs"): Returns: A Pipeline class to start the stage AMS service """ - PipeMechanisms = {"fs": FSPipeline, "network": RMQPipeline} - if src_mechanism not in PipeMechanisms.keys(): + pipe_mechanisms = {"fs": FSPipeline, "network": RMQPipeline} + if src_mechanism not in pipe_mechanisms.keys(): raise RuntimeError(f"Pipeline {src_mechanism} storing mechanism does not exist") - return PipeMechanisms[src_mechanism] + return pipe_mechanisms[src_mechanism] diff --git a/src/AMSWorkflow/ams_wf/AMSBroker.py b/src/AMSWorkflow/ams_wf/AMSBroker.py index ff75874c..57d65cc3 100644 --- a/src/AMSWorkflow/ams_wf/AMSBroker.py +++ b/src/AMSWorkflow/ams_wf/AMSBroker.py @@ -10,7 +10,7 @@ import os import sys -from ams.rmq import RMQClient +from ams.rmq import BlockingClient def main(): @@ -68,7 +68,7 @@ def main(): user = config["rabbitmq-user"] password = config["rabbitmq-password"] - with RMQClient(host, port, vhost, user, password, args.certificate) as client: + with BlockingClient(host, port, vhost, user, password, args.certificate) as client: with client.connect(args.queue) as channel: channel.send(args.msg_send) diff --git a/src/AMSWorkflow/ams_wf/AMSDBStage.py b/src/AMSWorkflow/ams_wf/AMSDBStage.py index 9657d37c..e8bb3ba4 100644 --- a/src/AMSWorkflow/ams_wf/AMSDBStage.py +++ b/src/AMSWorkflow/ams_wf/AMSDBStage.py @@ -7,15 +7,14 @@ import time from ams.loader import load_class +from ams.monitor import AMSMonitor from ams.stage import get_pipeline -import sys - def main(): parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter, - description="AMS Stage mechanism. The mechanism moves data to the file-system and optionally registers them in a kosh store", + description="AMS Stage mechanism. The mechanism moves data to the file-system and optionally registers them in a Kosh store", ) parser.add_argument( @@ -32,7 +31,7 @@ def main(): default="process", ) - parser.add_argument("--mechansism", "-m", dest="mechanism", choices=["fs", "network"], default="fs") + parser.add_argument("--mechanism", "-m", dest="mechanism", choices=["fs", "network"], default="fs") args, extras = parser.parse_known_args() @@ -80,6 +79,9 @@ def main(): pipeline.execute(args.policy) end = time.time() print(f"End to End time spend : {end - start}") + print(f"{AMSMonitor.info()}") + # Output profiling output to JSON (just as an example) + AMSMonitor.json("ams_monitor.json") if __name__ == "__main__":