@@ -73,6 +73,8 @@ def __init__(
73
73
self ._ps_pod_name_to_id = {}
74
74
self ._relaunch_deleted_live_ps = True
75
75
76
+ self ._failed_pods = []
77
+
76
78
self ._k8s_client = k8s .Client (event_callback = self ._event_cb , ** kwargs )
77
79
self ._ps_addrs = self ._get_addrs (
78
80
self ._num_ps , self ._k8s_client .get_ps_service_address
@@ -218,10 +220,18 @@ def _event_cb(self, event):
218
220
worker_id = - 1
219
221
ps_id = - 1
220
222
with self ._lock :
223
+ if pod_name in self ._failed_pods :
224
+ return
221
225
if pod_name in self ._worker_pod_name_to_id :
222
226
worker_id = self ._worker_pod_name_to_id .get (pod_name )
223
227
self ._worker_pods_phase [worker_id ] = (pod_name , phase )
224
- if evt_type == "DELETED" :
228
+ # Workaround for memory leak issues in tf eager mode.
229
+ # A pod may fail due to OOM caused by that leak, so a MODIFIED event whose phase is "Failed" is handled like a DELETED event.
230
+ failed_pod = False
231
+ if evt_type == "MODIFIED" and phase == "Failed" :
232
+ self ._failed_pods .append (pod_name )
233
+ failed_pod = True
234
+ if evt_type == "DELETED" or failed_pod :
225
235
del self ._worker_pods_phase [worker_id ]
226
236
del self ._worker_pod_name_to_id [pod_name ]
227
237
self ._task_d .recover_tasks (worker_id )
@@ -235,7 +245,13 @@ def _event_cb(self, event):
235
245
elif pod_name in self ._ps_pod_name_to_id :
236
246
ps_id = self ._ps_pod_name_to_id .get (pod_name )
237
247
self ._ps_pods_phase [ps_id ] = (pod_name , phase )
238
- if evt_type == "DELETED" :
248
+ # Workaround for memory leak issues in tf eager mode.
249
+ # A pod may fail due to OOM caused by that leak, so a MODIFIED event whose phase is "Failed" is handled like a DELETED event.
250
+ failed_pod = False
251
+ if evt_type == "MODIFIED" and phase == "Failed" :
252
+ self ._failed_pods .append (pod_name )
253
+ failed_pod = True
254
+ if evt_type == "DELETED" or failed_pod :
239
255
del self ._ps_pods_phase [ps_id ]
240
256
del self ._ps_pod_name_to_id [pod_name ]
241
257
relaunch_ps = self ._relaunch_deleted_live_ps
0 commit comments