Skip to content

Commit

Permalink
Add catch semantics to STM
Browse files Browse the repository at this point in the history
- Add support for Catch in IOSim and IOSimPOR
- Add support for Catch in Test/STM.hs
  • Loading branch information
yogeshsajanikar committed Nov 10, 2022
1 parent 8037962 commit a8da6e8
Show file tree
Hide file tree
Showing 6 changed files with 172 additions and 46 deletions.
36 changes: 31 additions & 5 deletions io-sim/src/Control/Monad/IOSim/Internal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -1148,19 +1148,45 @@ execAtomically !time !tid !tlbl !nextVid0 action0 k0 =

ThrowStm e ->
{-# SCC "execAtomically.go.ThrowStm" #-} do
-- Revert all the TVar writes
-- Rollback `TVar`s written since catch handler was installed
!_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written
k0 $ StmTxAborted [] (toException e)
case ctl of
AtomicallyFrame -> do
k0 $ StmTxAborted (Map.elems read) (toException e)

BranchFrame (CatchStmA h) k writtenOuter writtenOuterSeq createdOuterSeq ctl' ->
{-# SCC "execAtomically.go.BranchFrame" #-} do
-- Execute the left side in a new frame with an empty written set.
-- but preserve ones that were set prior to it, as specified in the
-- [stm](https://hackage.haskell.org/package/stm/docs/Control-Monad-STM.html#v:catchSTM) package.
let ctl'' = BranchFrame NoOpStmA k writtenOuter writtenOuterSeq createdOuterSeq ctl'
go ctl'' read Map.empty [] [] nextVid (h e)

BranchFrame (OrElseStmA _r) _k writtenOuter writtenOuterSeq createdOuterSeq ctl' ->
{-# SCC "execAtomically.go.BranchFrame" #-} do
go ctl' read writtenOuter writtenOuterSeq createdOuterSeq nextVid (ThrowStm e)

BranchFrame NoOpStmA _k writtenOuter writtenOuterSeq createdOuterSeq ctl' ->
{-# SCC "execAtomically.go.BranchFrame" #-} do
go ctl' read writtenOuter writtenOuterSeq createdOuterSeq nextVid (ThrowStm e)

CatchStm a h k ->
{-# SCC "execAtomically.go.ThrowStm" #-} do
-- Execute the catch handler with an empty written set.
-- but preserve ones that were set prior to it, as specified in the
-- [stm](https://hackage.haskell.org/package/stm/docs/Control-Monad-STM.html#v:catchSTM) package.
let ctl' = BranchFrame (CatchStmA h) k written writtenSeq createdSeq ctl
go ctl' read Map.empty [] [] nextVid a


Retry ->
{-# SCC "execAtomically.go.Retry" #-}
do
{-# SCC "execAtomically.go.Retry" #-} do
-- Always revert all the TVar writes for the retry
!_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written
case ctl of
AtomicallyFrame -> do
-- Return vars read, so the thread can block on them
k0 $! StmTxBlocked $! (Map.elems read)
k0 $! StmTxBlocked $! Map.elems read

BranchFrame (OrElseStmA b) k writtenOuter writtenOuterSeq createdOuterSeq ctl' ->
{-# SCC "execAtomically.go.BranchFrame.OrElseStmA" #-} do
Expand Down
6 changes: 2 additions & 4 deletions io-sim/src/Control/Monad/IOSim/STM.hs
Original file line number Diff line number Diff line change
Expand Up @@ -171,12 +171,10 @@ writeTBQueueDefault (TBQueue queue _size) a = do

isEmptyTBQueueDefault :: MonadSTM m => TBQueueDefault m a -> STM m Bool
isEmptyTBQueueDefault (TBQueue queue _size) = do
(xs, _, ys, _) <- readTVar queue
(xs, _, _, _) <- readTVar queue
case xs of
_:_ -> return False
[] -> case ys of
[] -> return True
_ -> return False
[] -> return True

isFullTBQueueDefault :: MonadSTM m => TBQueueDefault m a -> STM m Bool
isFullTBQueueDefault (TBQueue queue _size) = do
Expand Down
45 changes: 42 additions & 3 deletions io-sim/src/Control/Monad/IOSim/Types.hs
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ runSTM (STM k) = k ReturnStm
data StmA s a where
ReturnStm :: a -> StmA s a
ThrowStm :: SomeException -> StmA s a
CatchStm :: StmA s a -> (SomeException -> StmA s a) -> (a -> StmA s b) -> StmA s b

NewTVar :: Maybe String -> x -> (TVar s x -> StmA s b) -> StmA s b
LabelTVar :: String -> TVar s a -> StmA s b -> StmA s b
Expand Down Expand Up @@ -339,6 +340,31 @@ instance MonadThrow (STM s) where
instance Exceptions.MonadThrow (STM s) where
throwM = MonadThrow.throwIO


instance MonadCatch (STM s) where

catch action handler = STM $ oneShot $ \k -> CatchStm (runSTM action) (runSTM . fromHandler handler) k
where
-- Get a total handler from the given handler
fromHandler :: Exception e => (e -> STM s a) -> SomeException -> STM s a
fromHandler h e = case fromException e of
Nothing -> throwIO e -- Rethrow the exception if handler does not handle it.
Just e' -> h e'

-- Masking is not required as STM actions are always run inside
-- `execAtomically` and behave as if masked. Also note that the default
-- implementation of `generalBracket` needs mask, and is part of `MonadThrow`.
generalBracket acquire release use = do
resource <- acquire
b <- use resource `catch` \e -> do
_ <- release resource (ExitCaseException e)
throwIO e
c <- release resource (ExitCaseSuccess b)
return (b, c)

instance Exceptions.MonadCatch (STM s) where
catch = MonadThrow.catch

instance MonadCatch (IOSim s) where
catch action handler =
IOSim $ oneShot $ \k -> Catch (runIOSim action) (runIOSim . handler) k
Expand Down Expand Up @@ -867,9 +893,22 @@ data StmTxResult s a =
| StmTxAborted [SomeTVar s] SomeException


-- | OrElse/Catch give rise to an alternate right hand side branch. A right branch
-- can be a NoOp
data BranchStmA s a = OrElseStmA (StmA s a) | NoOpStmA
-- | A branch indicates that an alternative statement is available in the current
-- context. For example, `OrElse` has two alternative statements, say "left"
-- and "right". While executing the left statement, `OrElseStmA` branch indicates
-- that the right branch is still available, in case the left statement fails.
data BranchStmA s a =
-- | `OrElse` statement with its 'right' alternative.
OrElseStmA (StmA s a)
-- | `CatchStm` statement with the 'catch' handler.
| CatchStmA (SomeException -> StmA s a)
-- | Unlike the other two branches, the no-op branch is not an explicit
-- part of the STM syntax. It simply indicates that there are no
-- alternative statements left to be executed. For example, when running
-- right alternative of the `OrElse` statement or when running the catch
-- handler of a `CatchStm` statement, there are no alternative statements
-- available. This case is represented by the no-op branch.
| NoOpStmA

data StmStack s b a where
-- | Executing in the context of a top level 'atomically'.
Expand Down
70 changes: 46 additions & 24 deletions io-sim/src/Control/Monad/IOSimPOR/Internal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -1391,32 +1391,54 @@ execAtomically time tid tlbl nextVid0 action0 k0 =
{-# SCC "execAtomically.go.ThrowStm" #-} do
-- Revert all the TVar writes
!_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written
k0 $ StmTxAborted (Map.elems read) (toException e)
case ctl of
AtomicallyFrame -> do
k0 $ StmTxAborted (Map.elems read) (toException e)

BranchFrame (CatchStmA h) k writtenOuter writtenOuterSeq createdOuterSeq ctl' ->
{-# SCC "execAtomically.go.BranchFrame" #-} do
-- Execute the left side in a new frame with an empty written set.
-- but preserve ones that were set prior to it, as specified in the
-- [stm](https://hackage.haskell.org/package/stm/docs/Control-Monad-STM.html#v:catchSTM) package.
let ctl'' = BranchFrame NoOpStmA k writtenOuter writtenOuterSeq createdOuterSeq ctl'
go ctl'' read Map.empty [] [] nextVid (h e)

BranchFrame (OrElseStmA _r) _k writtenOuter writtenOuterSeq createdOuterSeq ctl' ->
{-# SCC "execAtomically.go.BranchFrame" #-} do
go ctl' read writtenOuter writtenOuterSeq createdOuterSeq nextVid (ThrowStm e)

BranchFrame NoOpStmA _k writtenOuter writtenOuterSeq createdOuterSeq ctl' ->
{-# SCC "execAtomically.go.BranchFrame" #-} do
go ctl' read writtenOuter writtenOuterSeq createdOuterSeq nextVid (ThrowStm e)

CatchStm a h k ->
{-# SCC "execAtomically.go.ThrowStm" #-} do
-- Execute the left side in a new frame with an empty written set
let ctl' = BranchFrame (CatchStmA h) k written writtenSeq createdSeq ctl
go ctl' read Map.empty [] [] nextVid a

Retry ->
{-# SCC "execAtomically.go.Retry" #-}
do
-- Always revert all the TVar writes for the retry
!_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written
case ctl of
AtomicallyFrame -> do
-- Return vars read, so the thread can block on them
k0 $! StmTxBlocked $! Map.elems read

BranchFrame (OrElseStmA b) k writtenOuter writtenOuterSeq createdOuterSeq ctl' ->
{-# SCC "execAtomically.go.BranchFrame.OrElseStmA" #-} do
!_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written
-- Execute the orElse right hand with an empty written set
let ctl'' = BranchFrame NoOpStmA k writtenOuter writtenOuterSeq createdOuterSeq ctl'
go ctl'' read Map.empty [] [] nextVid b

BranchFrame _ _k writtenOuter writtenOuterSeq createdOuterSeq ctl' ->
{-# SCC "execAtomically.go.BranchFrame" #-} do
-- Retry makes sense only within a OrElse context. If it is a branch other than
-- OrElse left side, then bubble up the `retry` to the frame above.
-- Skip the continuation and propagate the retry into the outer frame
-- using the written set for the outer frame
go ctl' read writtenOuter writtenOuterSeq createdOuterSeq nextVid Retry
{-# SCC "execAtomically.go.Retry" #-} do
-- Always revert all the TVar writes for the retry
!_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written
case ctl of
AtomicallyFrame -> do
-- Return vars read, so the thread can block on them
k0 $! StmTxBlocked $! Map.elems read

BranchFrame (OrElseStmA b) k writtenOuter writtenOuterSeq createdOuterSeq ctl' ->
{-# SCC "execAtomically.go.BranchFrame.OrElseStmA" #-} do
-- Execute the orElse right hand with an empty written set
let ctl'' = BranchFrame NoOpStmA k writtenOuter writtenOuterSeq createdOuterSeq ctl'
go ctl'' read Map.empty [] [] nextVid b

BranchFrame _ _k writtenOuter writtenOuterSeq createdOuterSeq ctl' ->
{-# SCC "execAtomically.go.BranchFrame" #-} do
-- Retry makes sense only within a OrElse context. If it is a branch other than
-- OrElse left side, then bubble up the `retry` to the frame above.
-- Skip the continuation and propagate the retry into the outer frame
-- using the written set for the outer frame
go ctl' read writtenOuter writtenOuterSeq createdOuterSeq nextVid Retry

OrElse a b k ->
{-# SCC "execAtomically.go.OrElse" #-} do
Expand Down
2 changes: 1 addition & 1 deletion io-sim/test/Test/IOSim.hs
Original file line number Diff line number Diff line change
Expand Up @@ -1249,7 +1249,7 @@ prop_stm_referenceSim t =
-- | Compare the behaviour of the STM reference operational semantics with
-- the behaviour of any 'MonadSTM' STM implementation.
--
prop_stm_referenceM :: (MonadSTM m, MonadThrow (STM m), MonadCatch m)
prop_stm_referenceM :: (MonadSTM m, MonadCatch (STM m), MonadCatch m)
=> SomeTerm -> m Property
prop_stm_referenceM (SomeTerm _tyrep t) = do
let (r1, _heap) = evalAtomically t
Expand Down
59 changes: 50 additions & 9 deletions io-sim/test/Test/STM.hs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ data Term (t :: Type) where

Return :: Expr t -> Term t
Throw :: Expr a -> Term t
Catch :: Term t -> Term t -> Term t
Retry :: Term t

ReadTVar :: Name (TyVar t) -> Term t
Expand Down Expand Up @@ -296,7 +297,7 @@ deriving instance Show (NfTerm t)
-- | The STM transition rules. They reduce a 'Term' to a normal-form 'NfTerm'.
--
-- Compare the implementation of this against the operational semantics in
-- Figure 4 in the paper. Note that @catch@ is not included.
-- Figure 4 in the paper including the `Catch` semantics from the Appendix A.
--
evalTerm :: Env -> Heap -> Allocs -> Term t -> (NfTerm t, Heap, Allocs)
evalTerm !env !heap !allocs term = case term of
Expand All @@ -309,6 +310,30 @@ evalTerm !env !heap !allocs term = case term of
where
e' = evalExpr env e

-- Exception semantics are detailed in "Appendix A Exception semantics" p 12-13 of
-- <https://research.microsoft.com/en-us/um/people/simonpj/papers/stm/stm.pdf>
Catch t1 t2 ->
let (nf1, heap', allocs') = evalTerm env heap mempty t1 in case nf1 of

-- Rule XSTM1
-- M; heap, {} => return P; heap', allocs'
-- --------------------------------------------------------
-- S[catch M N]; heap, allocs => S[return P]; heap', allocs U allocs'
NfReturn v -> (NfReturn v, heap', allocs <> allocs')

-- Rule XSTM2
-- M; heap, {} => throw P; heap', allocs'
-- --------------------------------------------------------
-- S[catch M N]; heap, allocs => S[N P]; heap U allocs', allocs U allocs'
NfThrow _ -> evalTerm env (heap <> allocs') (allocs <> allocs') t2

-- Rule XSTM3
-- M; heap, {} => retry; heap', allocs'
-- --------------------------------------------------------
-- S[catch M N]; heap, allocs => S[retry]; heap, allocs
NfRetry -> (NfRetry, heap, allocs)


Retry -> (NfRetry, heap, allocs)

-- Rule READ
Expand Down Expand Up @@ -437,7 +462,7 @@ extendExecEnv (Name name _tyrep) v (ExecEnv env) =

-- | Execute an STM 'Term' in the 'STM' monad.
--
execTerm :: (MonadSTM m, MonadThrow (STM m))
execTerm :: (MonadSTM m, MonadCatch (STM m))
=> ExecEnv m
-> Term t
-> STM m (ExecValue m t)
Expand All @@ -451,6 +476,8 @@ execTerm env t =
let e' = execExpr env e
throwSTM =<< snapshotExecValue e'

Catch t1 t2 -> execTerm env t1 `catch` \(_ :: ImmValue) -> execTerm env t2

Retry -> retry

ReadTVar n -> do
Expand Down Expand Up @@ -491,7 +518,7 @@ snapshotExecValue (ExecValInt x) = return (ImmValInt x)
snapshotExecValue (ExecValVar v _) = fmap ImmValVar
(snapshotExecValue =<< readTVar v)

execAtomically :: forall m t. (MonadSTM m, MonadThrow (STM m), MonadCatch m)
execAtomically :: forall m t. (MonadSTM m, MonadCatch (STM m), MonadCatch m)
=> Term t -> m TxResult
execAtomically t =
toTxResult <$> try (atomically action')
Expand Down Expand Up @@ -657,7 +684,7 @@ genTerm env tyrep =
Nothing)
]

binTerm = frequency [ (2, bindTerm), (1, orElseTerm)]
binTerm = frequency [ (2, bindTerm), (1, orElseTerm), (1, catchTerm)]

bindTerm =
sized $ \sz -> do
Expand All @@ -671,10 +698,15 @@ genTerm env tyrep =
return (Bind t1 name t2)

orElseTerm =
sized $ \sz -> resize (sz `div` 2) $
scale (`div` 2) $
OrElse <$> genTerm env tyrep
<*> genTerm env tyrep

catchTerm =
scale (`div` 2) $
Catch <$> genTerm env tyrep
<*> genTerm env tyrep

genSomeExpr :: GenEnv -> Gen SomeExpr
genSomeExpr env =
oneof'
Expand Down Expand Up @@ -713,6 +745,8 @@ shrinkTerm t =
case t of
Return e -> [Return e' | e' <- shrinkExpr e]
Throw e -> [Throw e' | e' <- shrinkExpr e]
Catch t1 t2 -> [t1, t2]
++ [Catch t1' t2' | (t1', t2') <- liftShrink2 shrinkTerm shrinkTerm (t1, t2)]
Retry -> []
ReadTVar _ -> []

Expand All @@ -721,12 +755,10 @@ shrinkTerm t =
NewTVar e -> [NewTVar e' | e' <- shrinkExpr e]

Bind t1 n t2 -> [ t2 | nameId n `Set.notMember` freeNamesTerm t2 ]
++ [ Bind t1' n t2 | t1' <- shrinkTerm t1 ]
++ [ Bind t1 n t2' | t2' <- shrinkTerm t2 ]
++ [ Bind t1' n t2' | (t1', t2') <- liftShrink2 shrinkTerm shrinkTerm (t1, t2) ]

OrElse t1 t2 -> [t1, t2]
++ [ OrElse t1' t2 | t1' <- shrinkTerm t1 ]
++ [ OrElse t1 t2' | t2' <- shrinkTerm t2 ]
++ [ OrElse t1' t2' | (t1', t2') <- liftShrink2 shrinkTerm shrinkTerm (t1, t2) ]

shrinkExpr :: Expr t -> [Expr t]
shrinkExpr ExprUnit = []
Expand All @@ -738,6 +770,10 @@ shrinkExpr (ExprName (Name _ (TyRepVar _))) = []
freeNamesTerm :: Term t -> Set NameId
freeNamesTerm (Return e) = freeNamesExpr e
freeNamesTerm (Throw e) = freeNamesExpr e
-- The current generator of catch term ignores the argument passed to the
-- handler.
-- TODO: Correctly handle free names when the handler also binds a variable.
freeNamesTerm (Catch t1 t2) = freeNamesTerm t1 <> freeNamesTerm t2
freeNamesTerm Retry = Set.empty
freeNamesTerm (ReadTVar n) = Set.singleton (nameId n)
freeNamesTerm (WriteTVar n e) = Set.singleton (nameId n) <> freeNamesExpr e
Expand Down Expand Up @@ -768,6 +804,7 @@ prop_genSomeTerm (SomeTerm tyrep term) =
termSize :: Term a -> Int
termSize Return{} = 1
termSize Throw{} = 1
termSize (Catch a b) = 1 + termSize a + termSize b
termSize Retry{} = 1
termSize ReadTVar{} = 1
termSize WriteTVar{} = 1
Expand All @@ -778,6 +815,7 @@ termSize (OrElse a b) = 1 + termSize a + termSize b
termDepth :: Term a -> Int
termDepth Return{} = 1
termDepth Throw{} = 1
termDepth (Catch a b) = 1 + max (termDepth a) (termDepth b)
termDepth Retry{} = 1
termDepth ReadTVar{} = 1
termDepth WriteTVar{} = 1
Expand All @@ -790,6 +828,9 @@ showTerm p (Return e) = showParen (p > 10) $
showString "return " . showExpr 11 e
showTerm p (Throw e) = showParen (p > 10) $
showString "throwSTM " . showExpr 11 e
showTerm p (Catch t1 t2) = showParen (p > 9) $
showTerm 10 t1 . showString " `catch` "
. showTerm 10 t2
showTerm _ Retry = showString "retry"
showTerm p (ReadTVar n) = showParen (p > 10) $
showString "readTVar " . showName n
Expand Down

0 comments on commit a8da6e8

Please sign in to comment.