Skip to content

Commit

Permalink
Only shuffle 20% of the time
Browse files Browse the repository at this point in the history
  • Loading branch information
bolt12 committed Dec 18, 2023
1 parent 6a781b9 commit 9fe5c5e
Showing 1 changed file with 42 additions and 23 deletions.
65 changes: 42 additions & 23 deletions io-sim/src/Control/Monad/IOSim/Internal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ module Control.Monad.IOSim.Internal

import Prelude hiding (read)

import Data.Deque.Strict (Deque)
import qualified Data.Deque.Strict as Deque
import Data.Dynamic
import Data.Foldable (foldlM, toList, traverse_)
import qualified Data.List as List
Expand All @@ -60,8 +62,6 @@ import qualified Data.OrdPSQ as PSQ
import Data.Set (Set)
import qualified Data.Set as Set
import Data.Time (UTCTime (..), fromGregorian)
import Data.Deque.Strict (Deque)
import qualified Data.Deque.Strict as Deque

import Control.Exception (NonTermination (..), assert, throw)
import Control.Monad (join, when)
Expand All @@ -76,13 +76,16 @@ import Control.Monad.Class.MonadSTM hiding (STM)
import Control.Monad.Class.MonadSTM.Internal (TMVarDefault (TMVar))
import Control.Monad.Class.MonadThrow hiding (getMaskingState)
import Control.Monad.Class.MonadTime
import Control.Monad.Class.MonadTimer.SI (TimeoutState (..), DiffTime, diffTimeToMicrosecondsAsInt, microsecondsAsIntToDiffTime)
import Control.Monad.Class.MonadTimer.SI (DiffTime, TimeoutState (..),
diffTimeToMicrosecondsAsInt, microsecondsAsIntToDiffTime)

import Control.Monad.IOSim.InternalTypes
import Control.Monad.IOSim.Types hiding (SimEvent (SimPOREvent),
Trace (SimPORTrace))
import Control.Monad.IOSim.Types (SimEvent)
import System.Random (StdGen, randomR, split)
import Data.Bifunctor (first)
import Data.Ord (comparing)
import System.Random (StdGen, randomR, split)

--
-- Simulation interpreter
Expand Down Expand Up @@ -849,31 +852,47 @@ reschedule !simstate@SimState{ threads, timers, curTime = time } =
timeoutSTMAction TimerTimeout{} = return ()

unblockThreads :: Bool -> [IOSimThreadId] -> SimState s a -> ([IOSimThreadId], SimState s a)
unblockThreads !onlySTM !wakeup !simstate@SimState {runqueue, threads, stdGen} =
unblockThreads !onlySTM !wakeup simstate@SimState {runqueue, threads, stdGen} =
-- To preserve our invariants (that threadBlocked is correct)
-- we update the runqueue and threads together here
(unblocked, simstate {
runqueue = Deque.fromList (shuffledRunqueue ++ rest),
runqueue = runqueue <> Deque.fromList unblocked,
threads = threads',
stdGen = stdGen''
})
where
!(shuffledRunqueue, stdGen'') = fisherYatesShuffle stdGen' toShuffle
!((toShuffle, rest), stdGen') =
let runqueueList = Deque.toList $ runqueue <> Deque.fromList unblocked
runqueueListLength = max 1 (length runqueueList)
(ix, newGen) = randomR (0, runqueueListLength `div` 2) stdGen
in (splitAt ix runqueueList, newGen)
-- can only unblock if the thread exists and is blocked (not running)
!unblocked = [ tid
| tid <- wakeup
, case Map.lookup tid threads of
Just Thread { threadStatus = ThreadBlocked BlockedOnSTM }
-> True
Just Thread { threadStatus = ThreadBlocked _ }
-> not onlySTM
_ -> False
]
!blockedOnOther = [ (tid, ix)
| (tid, ix) <- zip wakeup [0 :: Int ..]
, case Map.lookup tid threads of
Just Thread { threadStatus = ThreadBlocked BlockedOnSTM }
-> False
Just Thread { threadStatus = ThreadBlocked _ }
-> not onlySTM
_ -> False
]

!blockedOnSTM = [ (tid, ix)
| (tid, ix) <- zip wakeup [0 :: Int ..]
, case Map.lookup tid threads of
Just Thread { threadStatus = ThreadBlocked BlockedOnSTM }
-> True
_ -> False
]

mergeByIndex :: Ord a => [(b, a)] -> [(b, a)] -> [b]
mergeByIndex a b = map fst $ List.sortBy (comparing snd) (a ++ b)

-- Shuffle only 1/5th of the time
(shouldShuffle, !stdGen') =
first (== 0) $ randomR (0 :: Int, 5) stdGen

(!shuffledBlockedOnSTM, !stdGen'')
| shouldShuffle = fisherYatesShuffle stdGen' blockedOnSTM
| otherwise = (blockedOnSTM, stdGen')

!unblocked = mergeByIndex blockedOnOther shuffledBlockedOnSTM

-- and in which case we mark them as now running
!threads' = List.foldl'
(flip (Map.adjust (\t -> t { threadStatus = ThreadRunning })))
Expand All @@ -889,8 +908,8 @@ unblockThreads !onlySTM !wakeup !simstate@SimState {runqueue, threads, stdGen} =
where
go 0 lst g = (lst, g)
go n lst g = let (k, newGen) = randomR (0, n) g
(x:xs) = drop k lst
swapped = take k lst ++ [lst !! n] ++ drop (k + 1) lst
(x:xs) = drop k lst
swapped = take k lst ++ [lst !! n] ++ drop (k + 1) lst
in go (n - 1) (take n swapped ++ [x] ++ drop n xs) newGen

-- | This function receives a list of TimerTimeout values that represent threads
Expand Down

0 comments on commit 9fe5c5e

Please sign in to comment.