Skip to content

Commit 7c22a2a

Browse files
Add catch semantics to STM
Add catch handler as branch to support catch
1 parent c720399 commit 7c22a2a

File tree

4 files changed

+77
-9
lines changed

4 files changed

+77
-9
lines changed

io-sim/src/Control/Monad/IOSim/Internal.hs

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -899,10 +899,33 @@ execAtomically !time !tid !tlbl !nextVid0 action0 k0 =
899899
go ctl' read written' writtenSeq' createdSeq' nextVid (k x)
900900

901901
ThrowStm e ->
902+
{-# SCC "execAtomically.go.ThrowStm" #-}
903+
case ctl of
904+
AtomicallyFrame -> do
905+
-- Revert all the TVar writes
906+
!_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written
907+
k0 $ StmTxAborted [] (toException e)
908+
909+
BranchFrame (CatchStmA h) k writtenOuter writtenOuterSeq createdOuterSeq ctl' ->
910+
{-# SCC "execAtomically.go.branchFrame" #-} do
911+
-- Revert all the TVar writes within this orElse
912+
!_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written
913+
-- Execute the catch handler with an empty written set
914+
let ctl'' = BranchFrame EmptyStmA k writtenOuter writtenOuterSeq createdOuterSeq ctl'
915+
go ctl'' read Map.empty [] [] nextVid (h e)
916+
--
917+
BranchFrame _ _k writtenOuter writtenOuterSeq createdOuterSeq ctl' ->
918+
{-# SCC "execAtomically.go.branchFrame" #-} do
919+
-- Revert all the TVar writes within this orElse
920+
!_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written
921+
go ctl' read writtenOuter writtenOuterSeq createdOuterSeq nextVid (ThrowStm e)
922+
923+
CatchStm a h k ->
902924
{-# SCC "execAtomically.go.ThrowStm" #-} do
903-
-- Revert all the TVar writes
904-
!_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written
905-
k0 $ StmTxAborted [] (toException e)
925+
-- Execute the left side in a new frame with an empty written set
926+
let ctl' = BranchFrame (CatchStmA h) k written writtenSeq createdSeq ctl
927+
go ctl' read Map.empty [] [] nextVid a
928+
906929

907930
Retry ->
908931
{-# SCC "execAtomically.go.Retry" #-}

io-sim/src/Control/Monad/IOSim/Types.hs

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,7 @@ runSTM (STM k) = k ReturnStm
177177
data StmA s a where
178178
ReturnStm :: a -> StmA s a
179179
ThrowStm :: SomeException -> StmA s a
180+
CatchStm :: StmA s a -> (SomeException -> StmA s a) -> (a -> StmA s b) -> StmA s b
180181

181182
NewTVar :: Maybe String -> x -> (TVar s x -> StmA s b) -> StmA s b
182183
LabelTVar :: String -> TVar s a -> StmA s b -> StmA s b
@@ -315,6 +316,25 @@ instance MonadThrow (STM s) where
315316
instance Exceptions.MonadThrow (STM s) where
316317
throwM = MonadThrow.throwIO
317318

319+
instance MonadCatch (STM s) where
320+
321+
catch action handler = STM $ oneShot $ \k -> CatchStm (runSTM action) (runSTM . handler') k
322+
where
323+
handler' e = case fromException e of
324+
Nothing -> throwIO e -- Rethrow the exception if handler does not handle it.
325+
Just e' -> handler e'
326+
327+
generalBracket acquire release use = do
328+
resource <- acquire
329+
b <- use resource `catch` \e -> do
330+
_ <- release resource (ExitCaseException e)
331+
throwIO e
332+
c <- release resource (ExitCaseSuccess b)
333+
return (b, c)
334+
335+
instance Exceptions.MonadCatch (STM s) where
336+
catch = MonadThrow.catch
337+
318338
instance MonadCatch (IOSim s) where
319339
catch action handler =
320340
IOSim $ oneShot $ \k -> Catch (runIOSim action) (runIOSim . handler) k
@@ -846,9 +866,11 @@ data StmTxResult s a =
846866
| StmTxAborted [SomeTVar s] SomeException
847867

848868

849-
-- | OrElse/Catch give rise to an alternate right hand side branch. A right branch
850-
-- can be a NoOp
851-
data BranchStmA s a = OrElseStmA (StmA s a) | NoOpStmA
869+
-- | OrElse/Catch give rise to an alternate branch. A branch of a branch is an
870+
-- empty one.
871+
data BranchStmA s a = OrElseStmA (StmA s a)
872+
| CatchStmA (SomeException -> StmA s a)
873+
| NoOpStmA
852874

853875
data StmStack s b a where
854876
-- | Executing in the context of a top level 'atomically'.

io-sim/src/Control/Monad/IOSimPOR/Internal.hs

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1110,10 +1110,32 @@ execAtomically time tid tlbl nextVid0 action0 k0 =
11101110
go ctl' read written' writtenSeq' createdSeq' nextVid (k x)
11111111

11121112
ThrowStm e ->
1113+
{-# SCC "execAtomically.go.ThrowStm" #-}
1114+
case ctl of
1115+
AtomicallyFrame -> do
1116+
-- Revert all the TVar writes
1117+
!_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written
1118+
k0 $ StmTxAborted [] (toException e)
1119+
1120+
BranchFrame (CatchStmA h) k writtenOuter writtenOuterSeq createdOuterSeq ctl' ->
1121+
{-# SCC "execAtomically.go.branchFrame" #-} do
1122+
-- Revert all the TVar writes within this orElse
1123+
!_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written
1124+
-- Execute the catch handler with an empty written set
1125+
let ctl'' = BranchFrame EmptyStmA k writtenOuter writtenOuterSeq createdOuterSeq ctl'
1126+
go ctl'' read Map.empty [] [] nextVid (h e)
1127+
--
1128+
BranchFrame _ _k writtenOuter writtenOuterSeq createdOuterSeq ctl' ->
1129+
{-# SCC "execAtomically.go.branchFrame" #-} do
1130+
-- Revert all the TVar writes within this orElse
1131+
!_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written
1132+
go ctl' read writtenOuter writtenOuterSeq createdOuterSeq nextVid (ThrowStm e)
1133+
1134+
CatchStm a h k ->
11131135
{-# SCC "execAtomically.go.ThrowStm" #-} do
1114-
-- Revert all the TVar writes
1115-
!_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written
1116-
k0 $ StmTxAborted (Map.elems read) (toException e)
1136+
-- Execute the left side in a new frame with an empty written set
1137+
let ctl' = BranchFrame (CatchStmA h) k written writtenSeq createdSeq ctl
1138+
go ctl' read Map.empty [] [] nextVid a
11171139

11181140
Retry ->
11191141
{-# SCC "execAtomically.go.Retry" #-}

io-sim/test/Test/STM.hs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ data Term (t :: Type) where
6767

6868
Return :: Expr t -> Term t
6969
Throw :: Expr a -> Term t
70+
Catch :: Term t -> Expr a -> Term t -> Term t
7071
Retry :: Term t
7172

7273
ReadTVar :: Name (TyVar t) -> Term t

0 commit comments

Comments
 (0)