@@ -9,7 +9,7 @@ use std::pin::Pin;
99use  std:: ptr; 
1010use  std:: sync:: atomic:: AtomicUsize ; 
1111use  std:: sync:: atomic:: Ordering :: { Acquire ,  SeqCst } ; 
12- use  std:: sync:: { Arc ,  Mutex ,  Weak } ; 
12+ use  std:: sync:: { Arc ,  Mutex ,  MutexGuard ,   Weak } ; 
1313
1414/// Future for the [`shared`](super::FutureExt::shared) method. 
1515#[ must_use = "futures do nothing unless you `.await` or poll them" ]  
@@ -81,6 +81,7 @@ const IDLE: usize = 0;
8181const  POLLING :  usize  = 1 ; 
8282const  COMPLETE :  usize  = 2 ; 
8383const  POISONED :  usize  = 3 ; 
84+ const  WOKEN_DURING_POLLING :  usize  = 4 ; 
8485
8586const  NULL_WAKER_KEY :  usize  = usize:: MAX ; 
8687
@@ -197,36 +198,47 @@ where
197198    } 
198199} 
199200
200- impl < Fut >  Inner < Fut > 
201- where 
202-     Fut :  Future , 
203-     Fut :: Output :  Clone , 
204- { 
205-     /// Registers the current task to receive a wakeup when we are awoken. 
206- fn  record_waker ( & self ,  waker_key :  & mut  usize ,  cx :  & mut  Context < ' _ > )  { 
207-         let  mut  wakers_guard = self . notifier . wakers . lock ( ) . unwrap ( ) ; 
208- 
209-         let  wakers_mut = wakers_guard. as_mut ( ) ; 
210- 
211-         let  wakers = match  wakers_mut { 
212-             Some ( wakers)  => wakers, 
213-             None  => return , 
214-         } ; 
215- 
216-         let  new_waker = cx. waker ( ) ; 
201+ /// Registers the current task to receive a wakeup when we are awoken. 
202+ fn  record_waker ( 
203+     wakers_guard :  & mut  MutexGuard < ' _ ,  Option < Slab < Option < Waker > > > > , 
204+     waker_key :  & mut  usize , 
205+     cx :  & mut  Context < ' _ > , 
206+ )  { 
207+     let  wakers = match  wakers_guard. as_mut ( )  { 
208+         Some ( wakers)  => wakers, 
209+         None  => return , 
210+     } ; 
211+ 
212+     let  new_waker = cx. waker ( ) ; 
213+ 
214+     if  * waker_key == NULL_WAKER_KEY  { 
215+         * waker_key = wakers. insert ( Some ( new_waker. clone ( ) ) ) ; 
216+     }  else  { 
217+         match  wakers[ * waker_key]  { 
218+             Some ( ref  old_waker)  if  new_waker. will_wake ( old_waker)  => { } 
219+             // Could use clone_from here, but Waker doesn't specialize it. 
220+             ref  mut  slot => * slot = Some ( new_waker. clone ( ) ) , 
221+         } 
222+     } 
223+     debug_assert ! ( * waker_key != NULL_WAKER_KEY ) ; 
224+ } 
217225
218-         if  * waker_key == NULL_WAKER_KEY  { 
219-             * waker_key = wakers. insert ( Some ( new_waker. clone ( ) ) ) ; 
220-         }  else  { 
221-             match  wakers[ * waker_key]  { 
222-                 Some ( ref  old_waker)  if  new_waker. will_wake ( old_waker)  => { } 
223-                 // Could use clone_from here, but Waker doesn't specialize it. 
224-                 ref  mut  slot => * slot = Some ( new_waker. clone ( ) ) , 
226+ /// Wakes all tasks that are registered to be woken. 
227+ fn  wake_all ( waker_guard :  & mut  MutexGuard < ' _ ,  Option < Slab < Option < Waker > > > > )  { 
228+     if  let  Some ( wakers)  = waker_guard. as_mut ( )  { 
229+         for  ( _key,  opt_waker)  in  wakers { 
230+             if  let  Some ( waker)  = opt_waker. take ( )  { 
231+                 waker. wake ( ) ; 
225232            } 
226233        } 
227-         debug_assert ! ( * waker_key != NULL_WAKER_KEY ) ; 
228234    } 
235+ } 
229236
237+ impl < Fut >  Inner < Fut > 
238+ where 
239+     Fut :  Future , 
240+     Fut :: Output :  Clone , 
241+ { 
230242    /// Safety: callers must first ensure that `inner.state` 
231243/// is `COMPLETE` 
232244unsafe  fn  take_or_clone_output ( self :  Arc < Self > )  -> Fut :: Output  { 
@@ -268,18 +280,22 @@ where
268280            return  unsafe  {  Poll :: Ready ( inner. take_or_clone_output ( ) )  } ; 
269281        } 
270282
271-         inner. record_waker ( & mut  this. waker_key ,  cx) ; 
283+         // Guard the state transition with mutex too 
284+         let  mut  wakers_guard = inner. notifier . wakers . lock ( ) . unwrap ( ) ; 
285+         record_waker ( & mut  wakers_guard,  & mut  this. waker_key ,  cx) ; 
272286
273-         match  inner
287+         let  prev =  inner
274288            . notifier 
275289            . state 
276290            . compare_exchange ( IDLE ,  POLLING ,  SeqCst ,  SeqCst ) 
277-             . unwrap_or_else ( |x| x) 
278-         { 
291+             . unwrap_or_else ( |x| x) ; 
292+         drop ( wakers_guard) ; 
293+ 
294+         match  prev { 
279295            IDLE  => { 
280296                // Lock acquired, fall through 
281297            } 
282-             POLLING  => { 
298+             POLLING  |  WOKEN_DURING_POLLING   => { 
283299                // Another task is currently polling, at this point we just want 
284300                // to ensure that the waker for this task is registered 
285301                this. inner  = Some ( inner) ; 
@@ -324,15 +340,21 @@ where
324340
325341            match  poll_result { 
326342                Poll :: Pending  => { 
327-                     if  inner. notifier . state . compare_exchange ( POLLING ,  IDLE ,  SeqCst ,  SeqCst ) . is_ok ( ) 
328-                     { 
329-                         // Success 
330-                         drop ( reset) ; 
331-                         this. inner  = Some ( inner) ; 
332-                         return  Poll :: Pending ; 
333-                     }  else  { 
334-                         unreachable ! ( ) 
343+                     match  inner. notifier . state . compare_exchange ( POLLING ,  IDLE ,  SeqCst ,  SeqCst )  { 
344+                         Ok ( POLLING )  => { }  // success 
345+                         Err ( WOKEN_DURING_POLLING )  => { 
346+                             // waker has been called inside future.poll, need to wake any new wakers registered 
347+                             let  mut  wakers = inner. notifier . wakers . lock ( ) . unwrap ( ) ; 
348+                             wake_all ( & mut  wakers) ; 
349+                             let  prev = inner. notifier . state . swap ( IDLE ,  SeqCst ) ; 
350+                             assert_eq ! ( prev,  WOKEN_DURING_POLLING ) ; 
351+                             drop ( wakers) ; 
352+                         } 
353+                         _ => unreachable ! ( ) , 
335354                    } 
355+                     drop ( reset) ; 
356+                     this. inner  = Some ( inner) ; 
357+                     return  Poll :: Pending ; 
336358                } 
337359                Poll :: Ready ( output)  => output, 
338360            } 
@@ -387,14 +409,9 @@ where
387409
388410impl  ArcWake  for  Notifier  { 
389411    fn  wake_by_ref ( arc_self :  & Arc < Self > )  { 
390-         let  wakers = & mut  * arc_self. wakers . lock ( ) . unwrap ( ) ; 
391-         if  let  Some ( wakers)  = wakers. as_mut ( )  { 
392-             for  ( _key,  opt_waker)  in  wakers { 
393-                 if  let  Some ( waker)  = opt_waker. take ( )  { 
394-                     waker. wake ( ) ; 
395-                 } 
396-             } 
397-         } 
412+         let  mut  wakers = arc_self. wakers . lock ( ) . unwrap ( ) ; 
413+         let  _ = arc_self. state . compare_exchange ( POLLING ,  WOKEN_DURING_POLLING ,  SeqCst ,  SeqCst ) ; 
414+         wake_all ( & mut  wakers) ; 
398415    } 
399416} 
400417
0 commit comments