@@ -2,6 +2,7 @@
 import time
 from threading import Thread
 
+import numpy as np
 from tensorflow.python.keras import metrics as metrics_module
 
 from elasticdl.proto import elasticdl_pb2
@@ -74,7 +75,9 @@ def report_evaluation_metrics(self, model_outputs, labels):
                 continue
             outputs = tensor_pb_to_ndarray(tensor_pb)
             for metric_inst in metrics.values():
-                metric_inst.update_state(labels, outputs)
+                self._update_metric_by_small_chunk(
+                    metric_inst, labels, outputs
+                )
 
     def get_evaluation_summary(self):
         if self._model_have_multiple_outputs:
@@ -92,6 +95,28 @@ def get_evaluation_summary(self):
             ].items()
         }
 
+    def reset_metric_states(self):
+        """Resets all of the metric state variables."""
+        for metrics in self._metrics_dict.values():
+            for metric_inst in metrics.values():
+                metric_inst.reset_states()
+
+    @staticmethod
+    def _update_metric_by_small_chunk(
+        metric, labels, outputs, chunk_length=500
+    ):
+        """The metric updates state in a thread launched by grpc, and the
+        memory will increase greatly if we update the metric with large
+        outputs. So we split the outputs and labels into small chunks and
+        update the metric chunk by chunk. See tensorflow [issue 35044](
+        https://github.com/tensorflow/tensorflow/issues/35044).
+        """
+        chunk_boundaries = np.asarray(range(0, len(labels), chunk_length))
+        label_chunks = np.array_split(labels, chunk_boundaries)
+        output_chunks = np.array_split(outputs, chunk_boundaries)
+        for label, output in zip(label_chunks, output_chunks):
+            metric.update_state(label, output)
+
 
 class _EvaluationTrigger(Thread):
     """A trigger which generates evaluation tasks periodically"""
@@ -201,9 +226,14 @@ def try_to_create_new_job(self):
                 elasticdl_pb2.EVALUATION, checkpoint_version
             )
             task_count = len(self._task_d._eval_todo)
-            self._eval_job = _EvaluationJob(
-                self._eval_metrics_fn(), checkpoint_version, task_count
-            )
+            if self._eval_job is None:
+                self._eval_job = _EvaluationJob(
+                    self._eval_metrics_fn(), checkpoint_version, task_count
+                )
+            else:
+                self._eval_job.model_version = checkpoint_version
+                self._eval_job._total_tasks = task_count
+                self._eval_job.reset_metric_states()
             return True
         return False
 
@@ -227,7 +257,10 @@ def add_evaluation_task_if_needed(self, master_locking, model_version):
     def report_evaluation_metrics(self, model_outputs, labels):
         if self._eval_job is None:
             return False
-        return self._eval_job.report_evaluation_metrics(model_outputs, labels)
+        with self._lock:
+            return self._eval_job.report_evaluation_metrics(
+                model_outputs, labels
+            )
 
     def complete_task(self):
         self._eval_job.complete_task()
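As a side note, below is a minimal stand-alone sketch of the chunked update added in _update_metric_by_small_chunk above. It is not part of the patch: tf.keras.metrics.BinaryAccuracy and the random data are stand-ins chosen only for illustration, whereas in ElasticDL the metric instances come from the self._eval_metrics_fn() callable seen in the diff. Because the split boundaries start at 0, np.array_split yields an empty leading chunk; a mean-style Keras metric accumulates zero samples for it, so the chunked result agrees with a single update_state call.

# Illustration only: a hypothetical stand-alone version of the chunked
# metric update. BinaryAccuracy is just an example metric.
import numpy as np
import tensorflow as tf


def update_metric_by_small_chunk(metric, labels, outputs, chunk_length=500):
    # Boundaries are 0, 500, 1000, ...; np.array_split therefore yields an
    # empty leading chunk followed by chunks of at most chunk_length rows.
    chunk_boundaries = np.asarray(range(0, len(labels), chunk_length))
    label_chunks = np.array_split(labels, chunk_boundaries)
    output_chunks = np.array_split(outputs, chunk_boundaries)
    for label, output in zip(label_chunks, output_chunks):
        metric.update_state(label, output)


labels = np.random.randint(0, 2, size=(1200, 1))
outputs = np.random.uniform(size=(1200, 1))

chunked = tf.keras.metrics.BinaryAccuracy()
update_metric_by_small_chunk(chunked, labels, outputs)

whole = tf.keras.metrics.BinaryAccuracy()
whole.update_state(labels, outputs)

# Both paths accumulate the same running totals, so the results agree.
print(float(chunked.result()), float(whole.result()))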