ElasticDL Memory Leak Debug
In ElasticDL, the Master, Worker, and ParameterServer (PS) instances can be killed by OOM caused by memory leaks. In this document, we summarize the memory-leak problems we found when training a model with ElasticDL.
In ElasticDL, each worker executes evaluation tasks and reports the model outputs and the corresponding labels to the master by gRPC. Then the master calls `tf.keras.metrics.update_state` to calculate the metrics. Generally, `grpc.server` uses multiple threads to process the gRPC requests from workers, and each thread calls `update_state` for its evaluation tasks. Memory leaks occur when `update_state` is executed by multiple threads.
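For illustration, the sketch below shows this pattern outside ElasticDL: several threads of a pool update one shared Keras metric, just as the gRPC handler threads in the master do. The metric type, batch size, and thread count are assumptions chosen for the example, not values taken from the ElasticDL code.

```python
import numpy as np
import tensorflow as tf
from concurrent import futures

# One metric shared by all handler threads, as in the master.
auc = tf.keras.metrics.AUC()

def report_evaluation_metrics(outputs, labels):
    # Each gRPC handler thread updates the shared metric.
    auc.update_state(labels, outputs)

with futures.ThreadPoolExecutor(max_workers=64) as executor:
    for _ in range(100):
        outputs = np.random.random((512, 1)).astype(np.float32)
        labels = np.random.randint(0, 2, size=(512, 1))
        executor.submit(report_evaluation_metrics, outputs, labels)
```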
The resource configuration:

```
--master_resource_request="cpu=1,memory=1024Mi,ephemeral-storage=1024Mi" \
--worker_resource_request="cpu=1,memory=1024Mi,ephemeral-storage=1024Mi" \
--ps_resource_request="cpu=1,memory=1024Mi,ephemeral-storage=1024Mi" \
```
- Use multiple threads with `max_workers=64` in the master `grpc.server` and train a DeepFM model in the model zoo.

```python
def _create_master_service(self, args):
    self.logger.info("Creating master service")
    server = grpc.server(
        futures.ThreadPoolExecutor(max_workers=64),
```
- Use a single thread with `max_workers=1` in the master `grpc.server` and train a DeepFM model in the model zoo.

```python
def _create_master_service(self, args):
    self.logger.info("Creating master service")
    server = grpc.server(
        futures.ThreadPoolExecutor(max_workers=1),
```
Executing `tf.keras.metrics.update_state` with multiple threads leads to a memory leak. The problem has been submitted as TensorFlow issue 35044. We find that the used memory is proportional to the size of the outputs and labels passed to `tf.keras.metrics.update_state`. So, to mitigate the memory leak, we can split the outputs and labels into small chunks and execute `tf.keras.metrics.update_state` on those small chunks.
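A minimal sketch of this mitigation is shown below. The helper name and chunk size are illustrative assumptions rather than the exact ElasticDL implementation.

```python
import tensorflow as tf

def update_metric_by_chunks(metric, labels, outputs, chunk_size=256):
    # Update the metric chunk by chunk so that each update_state call
    # only touches a small slice of the reported outputs and labels.
    num_samples = int(labels.shape[0])
    for start in range(0, num_samples, chunk_size):
        end = min(start + chunk_size, num_samples)
        metric.update_state(labels[start:end], outputs[start:end])

# Usage example (hypothetical):
# update_metric_by_chunks(tf.keras.metrics.AUC(), labels, outputs)
```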
The gRPC server in each PS instance uses multiple threads to receive the gradients reported by the workers. Each thread calls `opt.apply_gradients` to update the variables on the PS instance. Like `tf.keras.metrics.update_state` in the master, executing `opt.apply_gradients` with multiple threads also results in a memory leak.
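The pattern on the PS side looks roughly like the sketch below; the optimizer, variable shape, and thread count are assumptions for illustration and are not taken from the ElasticDL parameter server code.

```python
import numpy as np
import tensorflow as tf
from concurrent import futures

# Variables held by a PS instance and a shared optimizer (illustrative only).
variables = [tf.Variable(np.zeros((1000, 16), dtype=np.float32))]
opt = tf.keras.optimizers.SGD(learning_rate=0.1)

def push_gradients(grads):
    # Each gRPC handler thread applies the gradients reported by a worker.
    opt.apply_gradients(zip(grads, variables))

with futures.ThreadPoolExecutor(max_workers=64) as executor:
    for _ in range(100):
        grads = [tf.constant(np.random.random((1000, 16)), dtype=tf.float32)]
        executor.submit(push_gradients, grads)
```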
The resource configuration:

```
--master_resource_request="cpu=1,memory=1024Mi,ephemeral-storage=1024Mi" \
--worker_resource_request="cpu=1,memory=1024Mi,ephemeral-storage=1024Mi" \
--ps_resource_request="cpu=1,memory=1024Mi,ephemeral-storage=1024Mi" \
```
- Use multiple threads with `max_workers=64` in the PS `grpc.server` and train a DeepFM model in the model zoo.

```python
def prepare(self):
    server = grpc.server(
        futures.ThreadPoolExecutor(max_workers=64),
        options=[
            ("grpc.max_send_message_length", GRPC.MAX_SEND_MESSAGE_LENGTH),
            (
                "grpc.max_receive_message_length",
                GRPC.MAX_RECEIVE_MESSAGE_LENGTH,
            ),
        ],
    )
```
- Use a single thread with `max_workers=1` in the PS `grpc.server` and train a DeepFM model in the model zoo.

```python
def prepare(self):
    server = grpc.server(
        futures.ThreadPoolExecutor(max_workers=1),
        options=[
            ("grpc.max_send_message_length", GRPC.MAX_SEND_MESSAGE_LENGTH),
            (
                "grpc.max_receive_message_length",
                GRPC.MAX_RECEIVE_MESSAGE_LENGTH,
            ),
        ],
    )
```
The cause is the same as for `tf.keras.metrics.update_state` in the master. From experiments, we find that the used memory is proportional to the number of threads used by the gRPC server, so we can reduce the number of threads of the gRPC server in each PS instance. The details are in TensorFlow issue 35010.
1. Memory leaks when using tf.py_function with eager execution.

`tf.py_function` may create a static graph each time we call it under eager execution, and the static graph cannot be released after the call. ElasticDL always uses eager execution to train a Keras model, so the input of an ElasticDL embedding layer is an `EagerTensor` when executing the training loop. Therefore, we can call `lookup_embedding` directly when the inputs are `EagerTensor`, without wrapping it in `tf.py_function`.
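A rough sketch of this dispatch is shown below. `lookup_embedding` stands in for ElasticDL's lookup call and its signature here is an assumption; the sketch also uses `tf.executing_eagerly()` as the eager check instead of testing the input type directly.

```python
import tensorflow as tf

def embedding_lookup(inputs, lookup_embedding, layer_name, embedding_dim):
    # Under eager execution the inputs are EagerTensors, so call the
    # Python lookup function directly instead of going through
    # tf.py_function, which builds a new graph on every call.
    if tf.executing_eagerly():
        return lookup_embedding(inputs, layer_name, embedding_dim)
    # In graph mode, keep the tf.py_function wrapper.
    return tf.py_function(
        func=lambda ids: lookup_embedding(ids, layer_name, embedding_dim),
        inp=[inputs],
        Tout=tf.float32,
    )
```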
2. Memory leaks when using tf.strings.split in map_func for tf.data.Dataset.map with eager execution.

If we use `tf.strings.split` in `map_func` to process each element of a `tf.data.Dataset`, the used memory grows as we iterate over the dataset and is not freed after iteration. What's more, the used memory keeps growing significantly if we repeatedly create the same `tf.data.Dataset` instance. However, the used memory stays stable if we use `tf.py_function` to implement the split logic. The following script compares the two implementations:
```python
import time

import numpy as np
import psutil
import tensorflow as tf

FEATURE_COUNT = 400


# Mock feature names and feature data.
def gen_feature_names(feature_count):
    feature_names = []
    for i in range(feature_count):
        feature_names.append("f{}".format(i))
    return feature_names


def gen_samples(feature_names, sample_count=5000):
    for _ in range(sample_count):
        feature_str = ""
        for name in feature_names:
            feature_str += "{};".format(np.random.random())
        feature_str += str(np.random.randint(0, 2))
        yield feature_str


def dataset_fn(dataset):
    # Parse each record with tf.py_function.
    def _py_parse_data(record):
        record = record.numpy()
        feature_labels = bytes.decode(record).split(";")
        return feature_labels

    def _parse_data(*record):
        feature_values = record[0:-1]
        features = {}
        for i, feature_name in enumerate(FEATURE_NAMES):
            features[feature_name] = feature_values[i]
        label = tf.strings.to_number(record[-1], tf.int64)
        return features, label

    tout = [tf.string] * FEATURE_COUNT
    tout.append(tf.string)
    dataset = dataset.map(
        lambda record: tf.py_function(_py_parse_data, [record], tout)
    )
    dataset = dataset.map(_parse_data)
    dataset = dataset.shuffle(buffer_size=100)
    return dataset


def dataset_fn_using_split(dataset):
    # Parse each record with tf.strings.split.
    def _parse_data(record):
        feature_label = tf.strings.split([record], sep=";")[0]
        feature_values = feature_label[0:-1]
        features = {}
        for i, feature_name in enumerate(FEATURE_NAMES):
            features[feature_name] = feature_values[i]
        label = feature_label[-1]
        return features, label

    dataset = dataset.map(_parse_data)
    dataset = dataset.shuffle(buffer_size=100)
    return dataset


def create_dataset(feature_names, using_split=True):
    dataset = tf.data.Dataset.from_generator(
        lambda: gen_samples(feature_names), tf.string
    )
    if using_split:
        dataset = dataset_fn_using_split(dataset)
    else:
        dataset = dataset_fn(dataset)
    dataset = dataset.batch(512)
    return dataset


def view_used_mem():
    used_mem = psutil.virtual_memory().used
    print("used memory: {} MB".format(used_mem / 1024 / 1024))


FEATURE_NAMES = gen_feature_names(FEATURE_COUNT)

# Test the used memory when using tf.strings.split in `map_func`.
start_time = time.time()
for i in range(4):
    print("loop {}".format(i))
    view_used_mem()
    dataset = create_dataset(FEATURE_NAMES, using_split=True)
    for batch in dataset:
        pass
print("Consume time : {}".format(time.time() - start_time))
print("end")
view_used_mem()

# Test the used memory when using `tf.py_function`.
start_time = time.time()
for i in range(4):
    print("loop {}".format(i))
    view_used_mem()
    dataset = create_dataset(FEATURE_NAMES, using_split=False)
    for batch in dataset:
        pass
print("Consume time : {}".format(time.time() - start_time))
print("end")
view_used_mem()
```
From the experiment, we can confirm that `tf.strings.split` leads to the memory leak. Issue 35152 has been submitted to TensorFlow. As shown in the experiment above, we can use `tf.py_function` to implement the same logic as `tf.strings.split`. However, `tf.py_function` is much slower than `tf.strings.split`.