Skip to content

Commit 020c928

Browse files
authored
Split labels and preds to small chunks and update metrics with those chunks (#1560)
* split labels and preds to small chunks and update metrics with those chunks * remove annotation code lines * update code by pre-commit * update code by pre-commit * fix docstring
1 parent c37c1f7 commit 020c928

File tree

2 files changed

+50
-5
lines changed

2 files changed

+50
-5
lines changed

elasticdl/python/master/evaluation_service.py

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import time
33
from threading import Thread
44

5+
import numpy as np
56
from tensorflow.python.keras import metrics as metrics_module
67

78
from elasticdl.proto import elasticdl_pb2
@@ -74,7 +75,9 @@ def report_evaluation_metrics(self, model_outputs, labels):
7475
continue
7576
outputs = tensor_pb_to_ndarray(tensor_pb)
7677
for metric_inst in metrics.values():
77-
metric_inst.update_state(labels, outputs)
78+
self._update_metric_by_small_chunk(
79+
metric_inst, labels, outputs
80+
)
7881

7982
def get_evaluation_summary(self):
8083
if self._model_have_multiple_outputs:
@@ -92,6 +95,28 @@ def get_evaluation_summary(self):
9295
].items()
9396
}
9497

98+
def reset_metric_states(self):
    """Reset the state variables of every metric held by this job.

    Walks each metric mapping stored in ``self._metrics_dict`` and calls
    ``reset_states()`` on every metric instance it contains.
    """
    every_metric = (
        metric
        for metric_group in self._metrics_dict.values()
        for metric in metric_group.values()
    )
    for metric in every_metric:
        metric.reset_states()
103+
104+
@staticmethod
105+
def _update_metric_by_small_chunk(
106+
metric, labels, outputs, chunk_length=500
107+
):
108+
"""The metric updates state in a thread launched by grpc. The memory will
109+
increase greatly if we update the metric with large size outputs. So
110+
we split the outputs and labels to small chunks then update the metric
111+
with those small chunks. The [issue 35044](https://github.com/
112+
tensorflow/tensorflow/issues/35044) has been submitted to tensorflow.
113+
"""
114+
chunk_boundaries = np.asarray(range(0, len(labels), chunk_length))
115+
label_chunks = np.array_split(labels, chunk_boundaries)
116+
output_chunks = np.array_split(outputs, chunk_boundaries)
117+
for label, output in zip(label_chunks, output_chunks):
118+
metric.update_state(label, output)
119+
95120

96121
class _EvaluationTrigger(Thread):
97122
"""A trigger which generates evaluation tasks periodically"""
@@ -201,9 +226,14 @@ def try_to_create_new_job(self):
201226
elasticdl_pb2.EVALUATION, checkpoint_version
202227
)
203228
task_count = len(self._task_d._eval_todo)
204-
self._eval_job = _EvaluationJob(
205-
self._eval_metrics_fn(), checkpoint_version, task_count
206-
)
229+
if self._eval_job is None:
230+
self._eval_job = _EvaluationJob(
231+
self._eval_metrics_fn(), checkpoint_version, task_count
232+
)
233+
else:
234+
self._eval_job.model_version = checkpoint_version
235+
self._eval_job._total_tasks = task_count
236+
self._eval_job.reset_metric_states()
207237
return True
208238
return False
209239

@@ -227,7 +257,10 @@ def add_evaluation_task_if_needed(self, master_locking, model_version):
227257
def report_evaluation_metrics(self, model_outputs, labels):
228258
if self._eval_job is None:
229259
return False
230-
return self._eval_job.report_evaluation_metrics(model_outputs, labels)
260+
with self._lock:
261+
return self._eval_job.report_evaluation_metrics(
262+
model_outputs, labels
263+
)
231264

232265
def complete_task(self):
233266
self._eval_job.complete_task()

elasticdl/python/tests/evaluation_service_test.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,18 @@ def testNeedEvaluation(self):
179179
evaluation_service._eval_checkpoint_versions, [20, 30]
180180
)
181181

182+
def test_update_metric_by_small_chunks(self):
    """Chunked metric updates must yield the same AUC as a single update."""
    labels = np.random.randint(0, 2, 1234)
    preds = np.random.random(1234)
    auc = tf.keras.metrics.AUC()

    # Reference value: feed all samples in one `update_state` call.
    auc.update_state(labels, preds)
    auc_value_0 = auc.result()

    # Same samples again, this time fed through the chunking helper.
    auc.reset_states()
    _EvaluationJob._update_metric_by_small_chunk(auc, labels, preds)
    auc_value_1 = auc.result()

    # `assertEquals` is a deprecated alias that was removed in Python 3.12;
    # `assertEqual` is the supported spelling.
    self.assertEqual(auc_value_0, auc_value_1)
193+
182194

183195
if __name__ == "__main__":
184196
unittest.main()

0 commit comments

Comments
 (0)