Skip to content

Commit a8da6e8

Browse files
Add catch semantics to STM
- Add support for Catch in IOSim and IOSimPOR - Add support for Catch in Test/STM.hs
1 parent 8037962 commit a8da6e8

File tree

6 files changed

+172
-46
lines changed

6 files changed

+172
-46
lines changed

io-sim/src/Control/Monad/IOSim/Internal.hs

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1148,19 +1148,45 @@ execAtomically !time !tid !tlbl !nextVid0 action0 k0 =
11481148

11491149
ThrowStm e ->
11501150
{-# SCC "execAtomically.go.ThrowStm" #-} do
1151-
-- Revert all the TVar writes
1151+
-- Rollback `TVar`s written since catch handler was installed
11521152
!_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written
1153-
k0 $ StmTxAborted [] (toException e)
1153+
case ctl of
1154+
AtomicallyFrame -> do
1155+
k0 $ StmTxAborted (Map.elems read) (toException e)
1156+
1157+
BranchFrame (CatchStmA h) k writtenOuter writtenOuterSeq createdOuterSeq ctl' ->
1158+
{-# SCC "execAtomically.go.BranchFrame" #-} do
1159+
-- Execute the left side in a new frame with an empty written set.
1160+
-- but preserve ones that were set prior to it, as specified in the
1161+
-- [stm](https://hackage.haskell.org/package/stm/docs/Control-Monad-STM.html#v:catchSTM) package.
1162+
let ctl'' = BranchFrame NoOpStmA k writtenOuter writtenOuterSeq createdOuterSeq ctl'
1163+
go ctl'' read Map.empty [] [] nextVid (h e)
1164+
1165+
BranchFrame (OrElseStmA _r) _k writtenOuter writtenOuterSeq createdOuterSeq ctl' ->
1166+
{-# SCC "execAtomically.go.BranchFrame" #-} do
1167+
go ctl' read writtenOuter writtenOuterSeq createdOuterSeq nextVid (ThrowStm e)
1168+
1169+
BranchFrame NoOpStmA _k writtenOuter writtenOuterSeq createdOuterSeq ctl' ->
1170+
{-# SCC "execAtomically.go.BranchFrame" #-} do
1171+
go ctl' read writtenOuter writtenOuterSeq createdOuterSeq nextVid (ThrowStm e)
1172+
1173+
CatchStm a h k ->
1174+
{-# SCC "execAtomically.go.ThrowStm" #-} do
1175+
-- Execute the catch handler with an empty written set.
1176+
-- but preserve ones that were set prior to it, as specified in the
1177+
-- [stm](https://hackage.haskell.org/package/stm/docs/Control-Monad-STM.html#v:catchSTM) package.
1178+
let ctl' = BranchFrame (CatchStmA h) k written writtenSeq createdSeq ctl
1179+
go ctl' read Map.empty [] [] nextVid a
1180+
11541181

11551182
Retry ->
1156-
{-# SCC "execAtomically.go.Retry" #-}
1157-
do
1183+
{-# SCC "execAtomically.go.Retry" #-} do
11581184
-- Always revert all the TVar writes for the retry
11591185
!_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written
11601186
case ctl of
11611187
AtomicallyFrame -> do
11621188
-- Return vars read, so the thread can block on them
1163-
k0 $! StmTxBlocked $! (Map.elems read)
1189+
k0 $! StmTxBlocked $! Map.elems read
11641190

11651191
BranchFrame (OrElseStmA b) k writtenOuter writtenOuterSeq createdOuterSeq ctl' ->
11661192
{-# SCC "execAtomically.go.BranchFrame.OrElseStmA" #-} do

io-sim/src/Control/Monad/IOSim/STM.hs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -171,12 +171,10 @@ writeTBQueueDefault (TBQueue queue _size) a = do
171171

172172
isEmptyTBQueueDefault :: MonadSTM m => TBQueueDefault m a -> STM m Bool
173173
isEmptyTBQueueDefault (TBQueue queue _size) = do
174-
(xs, _, ys, _) <- readTVar queue
174+
(xs, _, _, _) <- readTVar queue
175175
case xs of
176176
_:_ -> return False
177-
[] -> case ys of
178-
[] -> return True
179-
_ -> return False
177+
[] -> return True
180178

181179
isFullTBQueueDefault :: MonadSTM m => TBQueueDefault m a -> STM m Bool
182180
isFullTBQueueDefault (TBQueue queue _size) = do

io-sim/src/Control/Monad/IOSim/Types.hs

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,7 @@ runSTM (STM k) = k ReturnStm
195195
data StmA s a where
196196
ReturnStm :: a -> StmA s a
197197
ThrowStm :: SomeException -> StmA s a
198+
CatchStm :: StmA s a -> (SomeException -> StmA s a) -> (a -> StmA s b) -> StmA s b
198199

199200
NewTVar :: Maybe String -> x -> (TVar s x -> StmA s b) -> StmA s b
200201
LabelTVar :: String -> TVar s a -> StmA s b -> StmA s b
@@ -339,6 +340,31 @@ instance MonadThrow (STM s) where
339340
instance Exceptions.MonadThrow (STM s) where
340341
throwM = MonadThrow.throwIO
341342

343+
344+
instance MonadCatch (STM s) where
345+
346+
catch action handler = STM $ oneShot $ \k -> CatchStm (runSTM action) (runSTM . fromHandler handler) k
347+
where
348+
-- Get a total handler from the given handler
349+
fromHandler :: Exception e => (e -> STM s a) -> SomeException -> STM s a
350+
fromHandler h e = case fromException e of
351+
Nothing -> throwIO e -- Rethrow the exception if handler does not handle it.
352+
Just e' -> h e'
353+
354+
-- Masking is not required as STM actions are always run inside
355+
-- `execAtomically` and behave as if masked. Also note that the default
356+
-- implementation of `generalBracket` needs mask, and is part of `MonadThrow`.
357+
generalBracket acquire release use = do
358+
resource <- acquire
359+
b <- use resource `catch` \e -> do
360+
_ <- release resource (ExitCaseException e)
361+
throwIO e
362+
c <- release resource (ExitCaseSuccess b)
363+
return (b, c)
364+
365+
instance Exceptions.MonadCatch (STM s) where
366+
catch = MonadThrow.catch
367+
342368
instance MonadCatch (IOSim s) where
343369
catch action handler =
344370
IOSim $ oneShot $ \k -> Catch (runIOSim action) (runIOSim . handler) k
@@ -867,9 +893,22 @@ data StmTxResult s a =
867893
| StmTxAborted [SomeTVar s] SomeException
868894

869895

870-
-- | OrElse/Catch give rise to an alternate right hand side branch. A right branch
871-
-- can be a NoOp
872-
data BranchStmA s a = OrElseStmA (StmA s a) | NoOpStmA
896+
-- | A branch indicates that an alternative statement is available in the current
897+
-- context. For example, `OrElse` has two alternative statements, say "left"
898+
-- and "right". While executing the left statement, `OrElseStmA` branch indicates
899+
-- that the right branch is still available, in case the left statement fails.
900+
data BranchStmA s a =
901+
-- | `OrElse` statement with its 'right' alternative.
902+
OrElseStmA (StmA s a)
903+
-- | `CatchStm` statement with the 'catch' handler.
904+
| CatchStmA (SomeException -> StmA s a)
905+
-- | Unlike the other two branches, the no-op branch is not an explicit
906+
-- part of the STM syntax. It simply indicates that there are no
907+
-- alternative statements left to be executed. For example, when running
908+
-- right alternative of the `OrElse` statement or when running the catch
909+
-- handler of a `CatchStm` statement, there are no alternative statements
910+
-- available. This case is represented by the no-op branch.
911+
| NoOpStmA
873912

874913
data StmStack s b a where
875914
-- | Executing in the context of a top level 'atomically'.

io-sim/src/Control/Monad/IOSimPOR/Internal.hs

Lines changed: 46 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1391,32 +1391,54 @@ execAtomically time tid tlbl nextVid0 action0 k0 =
13911391
{-# SCC "execAtomically.go.ThrowStm" #-} do
13921392
-- Revert all the TVar writes
13931393
!_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written
1394-
k0 $ StmTxAborted (Map.elems read) (toException e)
1394+
case ctl of
1395+
AtomicallyFrame -> do
1396+
k0 $ StmTxAborted (Map.elems read) (toException e)
1397+
1398+
BranchFrame (CatchStmA h) k writtenOuter writtenOuterSeq createdOuterSeq ctl' ->
1399+
{-# SCC "execAtomically.go.BranchFrame" #-} do
1400+
-- Execute the left side in a new frame with an empty written set.
1401+
-- but preserve ones that were set prior to it, as specified in the
1402+
-- [stm](https://hackage.haskell.org/package/stm/docs/Control-Monad-STM.html#v:catchSTM) package.
1403+
let ctl'' = BranchFrame NoOpStmA k writtenOuter writtenOuterSeq createdOuterSeq ctl'
1404+
go ctl'' read Map.empty [] [] nextVid (h e)
1405+
1406+
BranchFrame (OrElseStmA _r) _k writtenOuter writtenOuterSeq createdOuterSeq ctl' ->
1407+
{-# SCC "execAtomically.go.BranchFrame" #-} do
1408+
go ctl' read writtenOuter writtenOuterSeq createdOuterSeq nextVid (ThrowStm e)
1409+
1410+
BranchFrame NoOpStmA _k writtenOuter writtenOuterSeq createdOuterSeq ctl' ->
1411+
{-# SCC "execAtomically.go.BranchFrame" #-} do
1412+
go ctl' read writtenOuter writtenOuterSeq createdOuterSeq nextVid (ThrowStm e)
1413+
1414+
CatchStm a h k ->
1415+
{-# SCC "execAtomically.go.ThrowStm" #-} do
1416+
-- Execute the left side in a new frame with an empty written set
1417+
let ctl' = BranchFrame (CatchStmA h) k written writtenSeq createdSeq ctl
1418+
go ctl' read Map.empty [] [] nextVid a
13951419

13961420
Retry ->
1397-
{-# SCC "execAtomically.go.Retry" #-}
1398-
do
1399-
-- Always revert all the TVar writes for the retry
1400-
!_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written
1401-
case ctl of
1402-
AtomicallyFrame -> do
1403-
-- Return vars read, so the thread can block on them
1404-
k0 $! StmTxBlocked $! Map.elems read
1405-
1406-
BranchFrame (OrElseStmA b) k writtenOuter writtenOuterSeq createdOuterSeq ctl' ->
1407-
{-# SCC "execAtomically.go.BranchFrame.OrElseStmA" #-} do
1408-
!_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written
1409-
-- Execute the orElse right hand with an empty written set
1410-
let ctl'' = BranchFrame NoOpStmA k writtenOuter writtenOuterSeq createdOuterSeq ctl'
1411-
go ctl'' read Map.empty [] [] nextVid b
1412-
1413-
BranchFrame _ _k writtenOuter writtenOuterSeq createdOuterSeq ctl' ->
1414-
{-# SCC "execAtomically.go.BranchFrame" #-} do
1415-
-- Retry makes sense only within a OrElse context. If it is a branch other than
1416-
-- OrElse left side, then bubble up the `retry` to the frame above.
1417-
-- Skip the continuation and propagate the retry into the outer frame
1418-
-- using the written set for the outer frame
1419-
go ctl' read writtenOuter writtenOuterSeq createdOuterSeq nextVid Retry
1421+
{-# SCC "execAtomically.go.Retry" #-} do
1422+
-- Always revert all the TVar writes for the retry
1423+
!_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written
1424+
case ctl of
1425+
AtomicallyFrame -> do
1426+
-- Return vars read, so the thread can block on them
1427+
k0 $! StmTxBlocked $! Map.elems read
1428+
1429+
BranchFrame (OrElseStmA b) k writtenOuter writtenOuterSeq createdOuterSeq ctl' ->
1430+
{-# SCC "execAtomically.go.BranchFrame.OrElseStmA" #-} do
1431+
-- Execute the orElse right hand with an empty written set
1432+
let ctl'' = BranchFrame NoOpStmA k writtenOuter writtenOuterSeq createdOuterSeq ctl'
1433+
go ctl'' read Map.empty [] [] nextVid b
1434+
1435+
BranchFrame _ _k writtenOuter writtenOuterSeq createdOuterSeq ctl' ->
1436+
{-# SCC "execAtomically.go.BranchFrame" #-} do
1437+
-- Retry makes sense only within a OrElse context. If it is a branch other than
1438+
-- OrElse left side, then bubble up the `retry` to the frame above.
1439+
-- Skip the continuation and propagate the retry into the outer frame
1440+
-- using the written set for the outer frame
1441+
go ctl' read writtenOuter writtenOuterSeq createdOuterSeq nextVid Retry
14201442

14211443
OrElse a b k ->
14221444
{-# SCC "execAtomically.go.OrElse" #-} do

io-sim/test/Test/IOSim.hs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1249,7 +1249,7 @@ prop_stm_referenceSim t =
12491249
-- | Compare the behaviour of the STM reference operational semantics with
12501250
-- the behaviour of any 'MonadSTM' STM implementation.
12511251
--
1252-
prop_stm_referenceM :: (MonadSTM m, MonadThrow (STM m), MonadCatch m)
1252+
prop_stm_referenceM :: (MonadSTM m, MonadCatch (STM m), MonadCatch m)
12531253
=> SomeTerm -> m Property
12541254
prop_stm_referenceM (SomeTerm _tyrep t) = do
12551255
let (r1, _heap) = evalAtomically t

io-sim/test/Test/STM.hs

Lines changed: 50 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ data Term (t :: Type) where
6767

6868
Return :: Expr t -> Term t
6969
Throw :: Expr a -> Term t
70+
Catch :: Term t -> Term t -> Term t
7071
Retry :: Term t
7172

7273
ReadTVar :: Name (TyVar t) -> Term t
@@ -296,7 +297,7 @@ deriving instance Show (NfTerm t)
296297
-- | The STM transition rules. They reduce a 'Term' to a normal-form 'NfTerm'.
297298
--
298299
-- Compare the implementation of this against the operational semantics in
299-
-- Figure 4 in the paper. Note that @catch@ is not included.
300+
-- Figure 4 in the paper including the `Catch` semantics from the Appendix A.
300301
--
301302
evalTerm :: Env -> Heap -> Allocs -> Term t -> (NfTerm t, Heap, Allocs)
302303
evalTerm !env !heap !allocs term = case term of
@@ -309,6 +310,30 @@ evalTerm !env !heap !allocs term = case term of
309310
where
310311
e' = evalExpr env e
311312

313+
-- Exception semantics are detailed in "Appendix A Exception semantics" p 12-13 of
314+
-- <https://research.microsoft.com/en-us/um/people/simonpj/papers/stm/stm.pdf>
315+
Catch t1 t2 ->
316+
let (nf1, heap', allocs') = evalTerm env heap mempty t1 in case nf1 of
317+
318+
-- Rule XSTM1
319+
-- M; heap, {} => return P; heap', allocs'
320+
-- --------------------------------------------------------
321+
-- S[catch M N]; heap, allocs => S[return P]; heap', allocs U allocs'
322+
NfReturn v -> (NfReturn v, heap', allocs <> allocs')
323+
324+
-- Rule XSTM2
325+
-- M; heap, {} => throw P; heap', allocs'
326+
-- --------------------------------------------------------
327+
-- S[catch M N]; heap, allocs => S[N P]; heap U allocs', allocs U allocs'
328+
NfThrow _ -> evalTerm env (heap <> allocs') (allocs <> allocs') t2
329+
330+
-- Rule XSTM3
331+
-- M; heap, {} => retry; heap', allocs'
332+
-- --------------------------------------------------------
333+
-- S[catch M N]; heap, allocs => S[retry]; heap, allocs
334+
NfRetry -> (NfRetry, heap, allocs)
335+
336+
312337
Retry -> (NfRetry, heap, allocs)
313338

314339
-- Rule READ
@@ -437,7 +462,7 @@ extendExecEnv (Name name _tyrep) v (ExecEnv env) =
437462

438463
-- | Execute an STM 'Term' in the 'STM' monad.
439464
--
440-
execTerm :: (MonadSTM m, MonadThrow (STM m))
465+
execTerm :: (MonadSTM m, MonadCatch (STM m))
441466
=> ExecEnv m
442467
-> Term t
443468
-> STM m (ExecValue m t)
@@ -451,6 +476,8 @@ execTerm env t =
451476
let e' = execExpr env e
452477
throwSTM =<< snapshotExecValue e'
453478

479+
Catch t1 t2 -> execTerm env t1 `catch` \(_ :: ImmValue) -> execTerm env t2
480+
454481
Retry -> retry
455482

456483
ReadTVar n -> do
@@ -491,7 +518,7 @@ snapshotExecValue (ExecValInt x) = return (ImmValInt x)
491518
snapshotExecValue (ExecValVar v _) = fmap ImmValVar
492519
(snapshotExecValue =<< readTVar v)
493520

494-
execAtomically :: forall m t. (MonadSTM m, MonadThrow (STM m), MonadCatch m)
521+
execAtomically :: forall m t. (MonadSTM m, MonadCatch (STM m), MonadCatch m)
495522
=> Term t -> m TxResult
496523
execAtomically t =
497524
toTxResult <$> try (atomically action')
@@ -657,7 +684,7 @@ genTerm env tyrep =
657684
Nothing)
658685
]
659686

660-
binTerm = frequency [ (2, bindTerm), (1, orElseTerm)]
687+
binTerm = frequency [ (2, bindTerm), (1, orElseTerm), (1, catchTerm)]
661688

662689
bindTerm =
663690
sized $ \sz -> do
@@ -671,10 +698,15 @@ genTerm env tyrep =
671698
return (Bind t1 name t2)
672699

673700
orElseTerm =
674-
sized $ \sz -> resize (sz `div` 2) $
701+
scale (`div` 2) $
675702
OrElse <$> genTerm env tyrep
676703
<*> genTerm env tyrep
677704

705+
catchTerm =
706+
scale (`div` 2) $
707+
Catch <$> genTerm env tyrep
708+
<*> genTerm env tyrep
709+
678710
genSomeExpr :: GenEnv -> Gen SomeExpr
679711
genSomeExpr env =
680712
oneof'
@@ -713,6 +745,8 @@ shrinkTerm t =
713745
case t of
714746
Return e -> [Return e' | e' <- shrinkExpr e]
715747
Throw e -> [Throw e' | e' <- shrinkExpr e]
748+
Catch t1 t2 -> [t1, t2]
749+
++ [Catch t1' t2' | (t1', t2') <- liftShrink2 shrinkTerm shrinkTerm (t1, t2)]
716750
Retry -> []
717751
ReadTVar _ -> []
718752

@@ -721,12 +755,10 @@ shrinkTerm t =
721755
NewTVar e -> [NewTVar e' | e' <- shrinkExpr e]
722756

723757
Bind t1 n t2 -> [ t2 | nameId n `Set.notMember` freeNamesTerm t2 ]
724-
++ [ Bind t1' n t2 | t1' <- shrinkTerm t1 ]
725-
++ [ Bind t1 n t2' | t2' <- shrinkTerm t2 ]
758+
++ [ Bind t1' n t2' | (t1', t2') <- liftShrink2 shrinkTerm shrinkTerm (t1, t2) ]
726759

727760
OrElse t1 t2 -> [t1, t2]
728-
++ [ OrElse t1' t2 | t1' <- shrinkTerm t1 ]
729-
++ [ OrElse t1 t2' | t2' <- shrinkTerm t2 ]
761+
++ [ OrElse t1' t2' | (t1', t2') <- liftShrink2 shrinkTerm shrinkTerm (t1, t2) ]
730762

731763
shrinkExpr :: Expr t -> [Expr t]
732764
shrinkExpr ExprUnit = []
@@ -738,6 +770,10 @@ shrinkExpr (ExprName (Name _ (TyRepVar _))) = []
738770
freeNamesTerm :: Term t -> Set NameId
739771
freeNamesTerm (Return e) = freeNamesExpr e
740772
freeNamesTerm (Throw e) = freeNamesExpr e
773+
-- The current generator of catch term ignores the argument passed to the
774+
-- handler.
775+
-- TODO: Correctly handle free names when the handler also binds a variable.
776+
freeNamesTerm (Catch t1 t2) = freeNamesTerm t1 <> freeNamesTerm t2
741777
freeNamesTerm Retry = Set.empty
742778
freeNamesTerm (ReadTVar n) = Set.singleton (nameId n)
743779
freeNamesTerm (WriteTVar n e) = Set.singleton (nameId n) <> freeNamesExpr e
@@ -768,6 +804,7 @@ prop_genSomeTerm (SomeTerm tyrep term) =
768804
termSize :: Term a -> Int
769805
termSize Return{} = 1
770806
termSize Throw{} = 1
807+
termSize (Catch a b) = 1 + termSize a + termSize b
771808
termSize Retry{} = 1
772809
termSize ReadTVar{} = 1
773810
termSize WriteTVar{} = 1
@@ -778,6 +815,7 @@ termSize (OrElse a b) = 1 + termSize a + termSize b
778815
termDepth :: Term a -> Int
779816
termDepth Return{} = 1
780817
termDepth Throw{} = 1
818+
termDepth (Catch a b) = 1 + max (termDepth a) (termDepth b)
781819
termDepth Retry{} = 1
782820
termDepth ReadTVar{} = 1
783821
termDepth WriteTVar{} = 1
@@ -790,6 +828,9 @@ showTerm p (Return e) = showParen (p > 10) $
790828
showString "return " . showExpr 11 e
791829
showTerm p (Throw e) = showParen (p > 10) $
792830
showString "throwSTM " . showExpr 11 e
831+
showTerm p (Catch t1 t2) = showParen (p > 9) $
832+
showTerm 10 t1 . showString " `catch` "
833+
. showTerm 10 t2
793834
showTerm _ Retry = showString "retry"
794835
showTerm p (ReadTVar n) = showParen (p > 10) $
795836
showString "readTVar " . showName n

0 commit comments

Comments
 (0)