Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add MonadCatch instance for STM #13

1 change: 1 addition & 0 deletions io-classes/src/Control/Monad/Class/MonadThrow.hs
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,7 @@ instance MonadEvaluate IO where
instance MonadThrow STM where
throwIO = STM.throwSTM


instance MonadCatch STM where
catch = STM.catchSTM

Expand Down
92 changes: 85 additions & 7 deletions io-sim/src/Control/Monad/IOSim/Internal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
-- incomplete uni patterns in 'schedule' (when interpreting 'StmTxCommitted')
-- and 'reschedule'.
{-# OPTIONS_GHC -Wno-incomplete-uni-patterns #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE PolyKinds #-}

module Control.Monad.IOSim.Internal
( IOSim (..)
Expand Down Expand Up @@ -71,9 +73,7 @@ import qualified Deque.Strict as Deque
import GHC.Exts (fromList)

import Control.Exception (NonTermination (..), assert, throw)
import Control.Monad (join)

import Control.Monad (when)
import Control.Monad (join, when)
import Control.Monad.ST.Lazy
import Control.Monad.ST.Lazy.Unsafe (unsafeIOToST, unsafeInterleaveST)
import Data.STRef.Lazy
Expand Down Expand Up @@ -828,6 +828,35 @@ runSimTraceST mainAction = schedule mainThread initialState
}


data StmControl s a where
StmControl :: StmA s b -> !(StmStack s b a) -> StmControl s a


-- Unwind the STM control stack till the matching exception is found
unwindControlStmStack :: forall s a.
SomeException
-> StmControl s a
-> Either Bool
( StmControl s a
, [Map TVarId (SomeTVar s)]
)
unwindControlStmStack e (StmControl _ frame) = unwindFrame [] frame

where
unwindFrame :: forall s' b. [Map TVarId (SomeTVar s')] -> StmStack s' b a -> Either Bool (StmControl s' a, [Map TVarId (SomeTVar s')])
unwindFrame _ AtomicallyFrame = Left True
unwindFrame ws (OrElseLeftFrame _ _ w _ _ ctl) = unwindFrame (w:ws) ctl
unwindFrame ws (OrElseRightFrame _ w _ _ ctl) = unwindFrame (w:ws) ctl
unwindFrame ws (CatchHandlerStmFrame _ _w _ _ ctl) = unwindFrame ws ctl -- Should not happen
unwindFrame ws (CatchStmFrame handler k writtenOuter writtenOuterSeq createdOuterSeq ctl) =
case fromException e of
-- Continue to unwind till we find a handler which can handle this exception.
Nothing -> unwindFrame (writtenOuter:ws) ctl
Just e' ->
let action' = handler e'
ctl' = CatchHandlerStmFrame k writtenOuter writtenOuterSeq createdOuterSeq ctl
in Right $ (StmControl action' ctl', reverse ws)

--
-- Executing STM Transactions
--
Expand Down Expand Up @@ -910,11 +939,47 @@ execAtomically !time !tid !tlbl !nextVid0 action0 k0 =
-- Continue with the k continuation
go ctl' read written' writtenSeq' createdSeq' nextVid (k x)

ThrowStm e ->
{-# SCC "execAtomically.go.ThrowStm" #-} do
CatchStmFrame _handler k writtenOuter writtenOuterSeq createdOuterSeq ctl' -> do
-- Successful main catch action
-- Merge allocations with outer sequence
!_ <- traverse_ (\(SomeTVar tvar) -> commitTVar tvar)
(Map.intersection written writtenOuter)
-- Merge the written set of the inner with the outer
let written' = Map.union written writtenOuter
writtenSeq' = filter (\(SomeTVar tvar) ->
tvarId tvar `Map.notMember` writtenOuter)
writtenSeq
++ writtenOuterSeq
-- Skip the orElse right hand and continue with the k continuation
go ctl' read written' writtenSeq' createdOuterSeq nextVid (k x)

CatchHandlerStmFrame k writtenOuter writtenOuterSeq createdOuterSeq ctl' -> do
-- Undo all written tvars
!_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written
go ctl' read writtenOuter writtenOuterSeq createdOuterSeq nextVid (k x)

ThrowStm e -> {-# SCC "execAtomically.go.ThrowStm" #-} do
revertThem written
case unwindControlStmStack e (StmControl action ctl) of

-- Unwind to the nearest matching exception
Right (StmControl action' ctl', ws) -> do
mapM_ revertThem ws
go ctl' read written writtenSeq createdSeq nextVid action'

-- Abort if no matching exception is found
Left{} ->
k0 $ StmTxAborted [] (toException e)

-- Revert all the TVar writes
!_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written
k0 $ StmTxAborted [] (toException e)
where
revertThem x =
traverse_ (\(SomeTVar tvar) -> revertTVar tvar) x

CatchStm act handler k ->
{-# SCC "execAtomically.go.CatchStm" #-} do
let ctl' = CatchStmFrame handler k written writtenSeq createdSeq ctl
go ctl' read Map.empty [] [] nextVid act

Retry ->
{-# SCC "execAtomically.go.Retry" #-}
Expand All @@ -941,6 +1006,19 @@ execAtomically !time !tid !tlbl !nextVid0 action0 k0 =
-- using the written set for the outer frame
go ctl' read writtenOuter writtenOuterSeq createdOuterSeq nextVid Retry

CatchStmFrame _handler _k writtenOuter writtenOuterSeq createdOuterSeq ctl' ->
{-# SCC "execAtomically.go.catchStmFrame" #-} do
-- This is XSTM3 test case from the STM paper.
-- Revert all the TVar writes within this catch action branch
!_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written
go ctl' read writtenOuter writtenOuterSeq createdOuterSeq nextVid Retry

CatchHandlerStmFrame _k writtenOuter writtenOuterSeq createdOuterSeq ctl' ->
{-# SCC "execAtomically.go.catchHandlerStmFrame" #-} do
-- Undo all written tvars
!_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written
go ctl' read writtenOuter writtenOuterSeq createdOuterSeq nextVid Retry

OrElse a b k ->
{-# SCC "execAtomically.go.OrElse" #-} do
-- Execute the left side in a new frame with an empty written set
Expand Down
42 changes: 42 additions & 0 deletions io-sim/src/Control/Monad/IOSim/Types.hs
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,8 @@ runSTM (STM k) = k ReturnStm
data StmA s a where
ReturnStm :: a -> StmA s a
ThrowStm :: SomeException -> StmA s a
-- Catch with continuation
CatchStm :: Exception e => StmA s a -> (e -> StmA s a) -> (a -> StmA s b) -> StmA s b

NewTVar :: Maybe String -> x -> (TVar s x -> StmA s b) -> StmA s b
LabelTVar :: String -> TVar s a -> StmA s b -> StmA s b
Expand Down Expand Up @@ -228,6 +230,8 @@ instance Monad (IOSim s) where
fail = Fail.fail
#endif



instance Semigroup a => Semigroup (IOSim s a) where
(<>) = liftA2 (<>)

Expand All @@ -238,6 +242,8 @@ instance Monoid a => Monoid (IOSim s a) where
mappend = liftA2 mappend
#endif



instance Fail.MonadFail (IOSim s) where
fail msg = IOSim $ oneShot $ \_ -> Throw (toException (IO.Error.userError msg))

Expand Down Expand Up @@ -273,6 +279,8 @@ instance Monad (STM s) where
fail = Fail.fail
#endif



instance Fail.MonadFail (STM s) where
fail msg = STM $ oneShot $ \_ -> ThrowStm (toException (ErrorCall msg))

Expand Down Expand Up @@ -313,6 +321,23 @@ instance MonadThrow (STM s) where
instance Exceptions.MonadThrow (STM s) where
throwM = MonadThrow.throwIO

instance MonadCatch (STM s) where

catch action handler = STM $ oneShot $ \k -> CatchStm (runSTM action) (runSTM . handler) k

-- Default implmentation uses mask. For STM, mask is not necessary.
generalBracket acquire release use = do
resource <- acquire
b <- use resource `catch` \e -> do
_ <- release resource (ExitCaseException e)
throwIO e
c <- release resource (ExitCaseSuccess b)
return (b, c)

instance Exceptions.MonadCatch (STM s) where

catch = MonadThrow.catch

instance MonadCatch (IOSim s) where
catch action handler =
IOSim $ oneShot $ \k -> Catch (runIOSim action) (runIOSim . handler) k
Expand Down Expand Up @@ -853,6 +878,23 @@ data StmStack s b a where
-> StmStack s b c
-> StmStack s a c

-- | Executing in the context of the /action/ part of the 'catch'
CatchStmFrame :: Exception e
=> (e -> StmA s a) -- exception handler
-> (a -> StmA s b) -- subsequent continuation
-> Map TVarId (SomeTVar s) -- saved written vars set
-> [SomeTVar s] -- saved written vars list
-> [SomeTVar s] -- created vars list (allocations)
-> StmStack s b c
-> StmStack s a c

-- | A continuation frame
CatchHandlerStmFrame :: (b -> StmA s c) -- subsequent continuation
-> Map TVarId (SomeTVar s) -- saved written vars set
-> [SomeTVar s] -- saved written vars list
-> [SomeTVar s] -- created vars list (allocations)
-> !(StmStack s c a)
-> StmStack s b a
---
--- Schedules
---
Expand Down
80 changes: 76 additions & 4 deletions io-sim/src/Control/Monad/IOSimPOR/Internal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -1039,6 +1039,32 @@ controlSimTraceST limit control mainAction =
}


data StmControl s a where
StmControl :: StmA s b -> !(StmStack s b a) -> StmControl s a


-- Unwind the STM control stack till the matching exception is found
unwindControlStmStack :: forall s a.
SomeException
-> StmControl s a
-> Either Bool (StmControl s a)
unwindControlStmStack e (StmControl _ frame) = unwindFrame frame

where
unwindFrame :: forall s' b. StmStack s' b a -> Either Bool (StmControl s' a)
unwindFrame AtomicallyFrame = Left True
unwindFrame (OrElseLeftFrame _ _ _ _ _ ctl) = unwindFrame ctl
unwindFrame (OrElseRightFrame _ _ _ _ ctl) = unwindFrame ctl
unwindFrame (CatchHandlerStmFrame _ _ _ _ ctl) = unwindFrame ctl
unwindFrame (CatchStmFrame handler k writtenOuter writtenOuterSeq createdOuterSeq ctl) =
case fromException e of
-- Continue to unwind till we find a handler which can handle this exception.
Nothing -> unwindFrame ctl
Just e' ->
let action' = handler e'
ctl' = CatchHandlerStmFrame k writtenOuter writtenOuterSeq createdOuterSeq ctl
in Right $ StmControl action' ctl'

--
-- Executing STM Transactions
--
Expand Down Expand Up @@ -1121,11 +1147,44 @@ execAtomically time tid tlbl nextVid0 action0 k0 =
-- Continue with the k continuation
go ctl' read written' writtenSeq' createdSeq' nextVid (k x)

CatchStmFrame _handler k writtenOuter writtenOuterSeq createdOuterSeq ctl' -> do
let written' = Map.union written writtenOuter
writtenSeq' = filter (\(SomeTVar tvar) ->
tvarId tvar `Map.notMember` writtenOuter)
writtenSeq
++ writtenOuterSeq
createdSeq' = createdSeq ++ createdOuterSeq
go ctl' read written' writtenSeq' createdSeq' nextVid (k x)

CatchHandlerStmFrame k writtenOuter writtenOuterSeq createdOuterSeq ctl' -> do
!_ <- traverse_ (\(SomeTVar tvar) -> commitTVar tvar)
(Map.intersection written writtenOuter)
let written' = Map.union written writtenOuter
writtenSeq' = filter (\(SomeTVar tvar) ->
tvarId tvar `Map.notMember` writtenOuter)
writtenSeq
++ writtenOuterSeq
createdSeq' = createdSeq ++ createdOuterSeq
go ctl' read written' writtenSeq' createdSeq' nextVid (k x)



ThrowStm e ->
{-# SCC "execAtomically.go.ThrowStm" #-} do
-- Revert all the TVar writes
!_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written
k0 $ StmTxAborted (Map.elems read) (toException e)
{-# SCC "execAtomically.go.ThrowStm" #-}

let abort = do
-- Revert all the TVar writes
!_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written
k0 $ StmTxAborted (Map.elems read) (toException e)

in case unwindControlStmStack e (StmControl action ctl) of
Left _ -> abort
Right (StmControl action' ctl') -> go ctl' read written writtenSeq createdSeq nextVid action'

CatchStm act handler k ->
{-# SCC "execAtomically.go.ThrowStm" #-} do
let ctl' = CatchStmFrame handler k written writtenSeq createdSeq ctl
go ctl' read Map.empty [] [] nextVid act

Retry ->
{-# SCC "execAtomically.go.Retry" #-}
Expand All @@ -1152,6 +1211,19 @@ execAtomically time tid tlbl nextVid0 action0 k0 =
-- using the written set for the outer frame
go ctl' read writtenOuter writtenOuterSeq createdOuterSeq nextVid Retry

CatchStmFrame _handler _k writtenOuter writtenOuterSeq createdOuterSeq ctl' ->
{-# SCC "execAtomically.go.catchStmFrame" #-} do
-- Revert all the TVar writes within this catch action branch
!_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written
go ctl' read writtenOuter writtenOuterSeq createdOuterSeq nextVid Retry

CatchHandlerStmFrame _k writtenOuter writtenOuterSeq createdOuterSeq ctl' ->
{-# SCC "execAtomically.go.catchHandlerStmFrame" #-} do
-- Revert all the TVar writes within this catch action branch
!_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written
go ctl' read writtenOuter writtenOuterSeq createdOuterSeq nextVid Retry


OrElse a b k ->
{-# SCC "execAtomically.go.OrElse" #-} do
-- Execute the left side in a new frame with an empty written set
Expand Down
5 changes: 4 additions & 1 deletion io-sim/test/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,15 @@ module Main (main) where
import Test.Tasty

import qualified Test.IOSim (tests)
import qualified Test.STM (tests)

main :: IO ()
main = defaultMain tests

tests :: TestTree
tests =
testGroup "IO Sim"
[ Test.IOSim.tests
[
Test.IOSim.tests
, Test.STM.tests
]
8 changes: 4 additions & 4 deletions io-sim/test/Test/IOSim.hs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ import Control.Monad.Class.MonadTime
import Control.Monad.Class.MonadTimer
import Control.Monad.IOSim

import Test.STM
import Test.STM hiding (tests)

import Test.QuickCheck
import Test.Tasty
Expand Down Expand Up @@ -134,8 +134,8 @@ tests =
, testProperty "16" unit_async_16
]
, testGroup "STM reference semantics"
[ testProperty "Reference vs IO" prop_stm_referenceIO
, testProperty "Reference vs Sim" prop_stm_referenceSim
[ testProperty "Reference vs IO" (withMaxSuccess 10000 prop_stm_referenceIO)
, testProperty "Reference vs Sim" (withMaxSuccess 10000 prop_stm_referenceSim)
]
, testGroup "MonadFix instance"
[ testProperty "purity" prop_mfix_purity
Expand Down Expand Up @@ -1049,7 +1049,7 @@ prop_stm_referenceSim t =
-- | Compare the behaviour of the STM reference operational semantics with
-- the behaviour of any 'MonadSTM' STM implementation.
--
prop_stm_referenceM :: (MonadSTM m, MonadThrow (STM m), MonadCatch m)
prop_stm_referenceM :: (MonadSTM m, MonadThrow (STM m), MonadCatch m, LazySTM.MonadSTM m, MonadCatch (LazySTM.STM m))
=> SomeTerm -> m Property
prop_stm_referenceM (SomeTerm _tyrep t) = do
let (r1, _heap) = evalAtomically t
Expand Down
Loading