Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CAD-4738 stm monad catch instance #16

Merged
merged 1 commit into from
Nov 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 31 additions & 5 deletions io-sim/src/Control/Monad/IOSim/Internal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -1148,19 +1148,45 @@ execAtomically !time !tid !tlbl !nextVid0 action0 k0 =

ThrowStm e ->
{-# SCC "execAtomically.go.ThrowStm" #-} do
-- Revert all the TVar writes
-- Rollback `TVar`s written since catch handler was installed
!_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written
k0 $ StmTxAborted [] (toException e)
case ctl of
AtomicallyFrame -> do
k0 $ StmTxAborted (Map.elems read) (toException e)

BranchFrame (CatchStmA h) k writtenOuter writtenOuterSeq createdOuterSeq ctl' ->
{-# SCC "execAtomically.go.BranchFrame" #-} do
-- Execute the left side in a new frame with an empty written set.
-- but preserve ones that were set prior to it, as specified in the
-- [stm](https://hackage.haskell.org/package/stm/docs/Control-Monad-STM.html#v:catchSTM) package.
let ctl'' = BranchFrame NoOpStmA k writtenOuter writtenOuterSeq createdOuterSeq ctl'
go ctl'' read Map.empty [] [] nextVid (h e)

BranchFrame (OrElseStmA _r) _k writtenOuter writtenOuterSeq createdOuterSeq ctl' ->
{-# SCC "execAtomically.go.BranchFrame" #-} do
go ctl' read writtenOuter writtenOuterSeq createdOuterSeq nextVid (ThrowStm e)

BranchFrame NoOpStmA _k writtenOuter writtenOuterSeq createdOuterSeq ctl' ->
{-# SCC "execAtomically.go.BranchFrame" #-} do
go ctl' read writtenOuter writtenOuterSeq createdOuterSeq nextVid (ThrowStm e)

CatchStm a h k ->
{-# SCC "execAtomically.go.ThrowStm" #-} do
-- Execute the catch handler with an empty written set.
-- but preserve ones that were set prior to it, as specified in the
-- [stm](https://hackage.haskell.org/package/stm/docs/Control-Monad-STM.html#v:catchSTM) package.
let ctl' = BranchFrame (CatchStmA h) k written writtenSeq createdSeq ctl
go ctl' read Map.empty [] [] nextVid a


Retry ->
{-# SCC "execAtomically.go.Retry" #-}
do
{-# SCC "execAtomically.go.Retry" #-} do
-- Always revert all the TVar writes for the retry
!_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written
case ctl of
AtomicallyFrame -> do
-- Return vars read, so the thread can block on them
k0 $! StmTxBlocked $! (Map.elems read)
k0 $! StmTxBlocked $! Map.elems read

BranchFrame (OrElseStmA b) k writtenOuter writtenOuterSeq createdOuterSeq ctl' ->
{-# SCC "execAtomically.go.BranchFrame.OrElseStmA" #-} do
Expand Down
45 changes: 42 additions & 3 deletions io-sim/src/Control/Monad/IOSim/Types.hs
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ runSTM (STM k) = k ReturnStm
data StmA s a where
ReturnStm :: a -> StmA s a
ThrowStm :: SomeException -> StmA s a
CatchStm :: StmA s a -> (SomeException -> StmA s a) -> (a -> StmA s b) -> StmA s b

NewTVar :: Maybe String -> x -> (TVar s x -> StmA s b) -> StmA s b
LabelTVar :: String -> TVar s a -> StmA s b -> StmA s b
Expand Down Expand Up @@ -339,6 +340,31 @@ instance MonadThrow (STM s) where
instance Exceptions.MonadThrow (STM s) where
throwM = MonadThrow.throwIO


instance MonadCatch (STM s) where

catch action handler = STM $ oneShot $ \k -> CatchStm (runSTM action) (runSTM . fromHandler handler) k
where
-- Get a total handler from the given handler
fromHandler :: Exception e => (e -> STM s a) -> SomeException -> STM s a
fromHandler h e = case fromException e of
Nothing -> throwIO e -- Rethrow the exception if handler does not handle it.
Just e' -> h e'

-- Masking is not required as STM actions are always run inside
-- `execAtomically` and behave as if masked. Also note that the default
-- implementation of `generalBracket` needs mask, and is part of `MonadThrow`.
generalBracket acquire release use = do
coot marked this conversation as resolved.
Show resolved Hide resolved
resource <- acquire
b <- use resource `catch` \e -> do
_ <- release resource (ExitCaseException e)
throwIO e
c <- release resource (ExitCaseSuccess b)
coot marked this conversation as resolved.
Show resolved Hide resolved
return (b, c)

instance Exceptions.MonadCatch (STM s) where
catch = MonadThrow.catch

instance MonadCatch (IOSim s) where
catch action handler =
IOSim $ oneShot $ \k -> Catch (runIOSim action) (runIOSim . handler) k
Expand Down Expand Up @@ -867,9 +893,22 @@ data StmTxResult s a =
| StmTxAborted [SomeTVar s] SomeException


-- | OrElse/Catch give rise to an alternate right hand side branch. A right branch
-- can be a NoOp
data BranchStmA s a = OrElseStmA (StmA s a) | NoOpStmA
-- | A branch indicates that an alternative statement is available in the current
-- context. For example, `OrElse` has two alternative statements, say "left"
-- and "right". While executing the left statement, `OrElseStmA` branch indicates
-- that the right branch is still available, in case the left statement fails.
data BranchStmA s a =
-- | `OrElse` statement with its 'right' alternative.
OrElseStmA (StmA s a)
-- | `CatchStm` statement with the 'catch' handler.
| CatchStmA (SomeException -> StmA s a)
-- | Unlike the other two branches, the no-op branch is not an explicit
-- part of the STM syntax. It simply indicates that there are no
-- alternative statements left to be executed. For example, when running
-- right alternative of the `OrElse` statement or when running the catch
-- handler of a `CatchStm` statement, there are no alternative statements
-- available. This case is represented by the no-op branch.
| NoOpStmA

data StmStack s b a where
-- | Executing in the context of a top level 'atomically'.
Expand Down
70 changes: 46 additions & 24 deletions io-sim/src/Control/Monad/IOSimPOR/Internal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -1391,32 +1391,54 @@ execAtomically time tid tlbl nextVid0 action0 k0 =
{-# SCC "execAtomically.go.ThrowStm" #-} do
-- Revert all the TVar writes
!_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written
k0 $ StmTxAborted (Map.elems read) (toException e)
case ctl of
AtomicallyFrame -> do
k0 $ StmTxAborted (Map.elems read) (toException e)

BranchFrame (CatchStmA h) k writtenOuter writtenOuterSeq createdOuterSeq ctl' ->
{-# SCC "execAtomically.go.BranchFrame" #-} do
-- Execute the left side in a new frame with an empty written set.
-- but preserve ones that were set prior to it, as specified in the
-- [stm](https://hackage.haskell.org/package/stm/docs/Control-Monad-STM.html#v:catchSTM) package.
let ctl'' = BranchFrame NoOpStmA k writtenOuter writtenOuterSeq createdOuterSeq ctl'
go ctl'' read Map.empty [] [] nextVid (h e)

BranchFrame (OrElseStmA _r) _k writtenOuter writtenOuterSeq createdOuterSeq ctl' ->
{-# SCC "execAtomically.go.BranchFrame" #-} do
go ctl' read writtenOuter writtenOuterSeq createdOuterSeq nextVid (ThrowStm e)

BranchFrame NoOpStmA _k writtenOuter writtenOuterSeq createdOuterSeq ctl' ->
{-# SCC "execAtomically.go.BranchFrame" #-} do
go ctl' read writtenOuter writtenOuterSeq createdOuterSeq nextVid (ThrowStm e)

CatchStm a h k ->
{-# SCC "execAtomically.go.ThrowStm" #-} do
-- Execute the left side in a new frame with an empty written set
let ctl' = BranchFrame (CatchStmA h) k written writtenSeq createdSeq ctl
go ctl' read Map.empty [] [] nextVid a

Retry ->
{-# SCC "execAtomically.go.Retry" #-}
do
-- Always revert all the TVar writes for the retry
!_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written
case ctl of
AtomicallyFrame -> do
-- Return vars read, so the thread can block on them
k0 $! StmTxBlocked $! Map.elems read

BranchFrame (OrElseStmA b) k writtenOuter writtenOuterSeq createdOuterSeq ctl' ->
{-# SCC "execAtomically.go.BranchFrame.OrElseStmA" #-} do
!_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written
-- Execute the orElse right hand with an empty written set
let ctl'' = BranchFrame NoOpStmA k writtenOuter writtenOuterSeq createdOuterSeq ctl'
go ctl'' read Map.empty [] [] nextVid b

BranchFrame _ _k writtenOuter writtenOuterSeq createdOuterSeq ctl' ->
{-# SCC "execAtomically.go.BranchFrame" #-} do
-- Retry makes sense only within a OrElse context. If it is a branch other than
-- OrElse left side, then bubble up the `retry` to the frame above.
-- Skip the continuation and propagate the retry into the outer frame
-- using the written set for the outer frame
go ctl' read writtenOuter writtenOuterSeq createdOuterSeq nextVid Retry
{-# SCC "execAtomically.go.Retry" #-} do
-- Always revert all the TVar writes for the retry
!_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written
case ctl of
AtomicallyFrame -> do
-- Return vars read, so the thread can block on them
k0 $! StmTxBlocked $! Map.elems read

BranchFrame (OrElseStmA b) k writtenOuter writtenOuterSeq createdOuterSeq ctl' ->
{-# SCC "execAtomically.go.BranchFrame.OrElseStmA" #-} do
-- Execute the orElse right hand with an empty written set
let ctl'' = BranchFrame NoOpStmA k writtenOuter writtenOuterSeq createdOuterSeq ctl'
go ctl'' read Map.empty [] [] nextVid b

BranchFrame _ _k writtenOuter writtenOuterSeq createdOuterSeq ctl' ->
nfrisby marked this conversation as resolved.
Show resolved Hide resolved
{-# SCC "execAtomically.go.BranchFrame" #-} do
-- Retry makes sense only within a OrElse context. If it is a branch other than
-- OrElse left side, then bubble up the `retry` to the frame above.
-- Skip the continuation and propagate the retry into the outer frame
-- using the written set for the outer frame
go ctl' read writtenOuter writtenOuterSeq createdOuterSeq nextVid Retry

OrElse a b k ->
{-# SCC "execAtomically.go.OrElse" #-} do
Expand Down
2 changes: 1 addition & 1 deletion io-sim/test/Test/IOSim.hs
Original file line number Diff line number Diff line change
Expand Up @@ -1249,7 +1249,7 @@ prop_stm_referenceSim t =
-- | Compare the behaviour of the STM reference operational semantics with
-- the behaviour of any 'MonadSTM' STM implementation.
--
prop_stm_referenceM :: (MonadSTM m, MonadThrow (STM m), MonadCatch m)
prop_stm_referenceM :: (MonadSTM m, MonadCatch (STM m), MonadCatch m)
=> SomeTerm -> m Property
prop_stm_referenceM (SomeTerm _tyrep t) = do
let (r1, _heap) = evalAtomically t
Expand Down
59 changes: 50 additions & 9 deletions io-sim/test/Test/STM.hs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ data Term (t :: Type) where

Return :: Expr t -> Term t
Throw :: Expr a -> Term t
Catch :: Term t -> Term t -> Term t
Retry :: Term t

ReadTVar :: Name (TyVar t) -> Term t
Expand Down Expand Up @@ -296,7 +297,7 @@ deriving instance Show (NfTerm t)
-- | The STM transition rules. They reduce a 'Term' to a normal-form 'NfTerm'.
--
-- Compare the implementation of this against the operational semantics in
-- Figure 4 in the paper. Note that @catch@ is not included.
-- Figure 4 in the paper including the `Catch` semantics from the Appendix A.
nfrisby marked this conversation as resolved.
Show resolved Hide resolved
--
evalTerm :: Env -> Heap -> Allocs -> Term t -> (NfTerm t, Heap, Allocs)
evalTerm !env !heap !allocs term = case term of
Expand All @@ -309,6 +310,30 @@ evalTerm !env !heap !allocs term = case term of
where
e' = evalExpr env e

-- Exception semantics are detailed in "Appendix A Exception semantics" p 12-13 of
coot marked this conversation as resolved.
Show resolved Hide resolved
-- <https://research.microsoft.com/en-us/um/people/simonpj/papers/stm/stm.pdf>
Catch t1 t2 ->
let (nf1, heap', allocs') = evalTerm env heap mempty t1 in case nf1 of

-- Rule XSTM1
-- M; heap, {} => return P; heap', allocs'
-- --------------------------------------------------------
-- S[catch M N]; heap, allocs => S[return P]; heap', allocs U allocs'
NfReturn v -> (NfReturn v, heap', allocs <> allocs')
nfrisby marked this conversation as resolved.
Show resolved Hide resolved

-- Rule XSTM2
-- M; heap, {} => throw P; heap', allocs'
-- --------------------------------------------------------
-- S[catch M N]; heap, allocs => S[N P]; heap U allocs', allocs U allocs'
NfThrow _ -> evalTerm env (heap <> allocs') (allocs <> allocs') t2

-- Rule XSTM3
-- M; heap, {} => retry; heap', allocs'
-- --------------------------------------------------------
-- S[catch M N]; heap, allocs => S[retry]; heap, allocs
NfRetry -> (NfRetry, heap, allocs)


Retry -> (NfRetry, heap, allocs)

-- Rule READ
Expand Down Expand Up @@ -437,7 +462,7 @@ extendExecEnv (Name name _tyrep) v (ExecEnv env) =

-- | Execute an STM 'Term' in the 'STM' monad.
--
execTerm :: (MonadSTM m, MonadThrow (STM m))
execTerm :: (MonadSTM m, MonadCatch (STM m))
=> ExecEnv m
-> Term t
-> STM m (ExecValue m t)
Expand All @@ -451,6 +476,8 @@ execTerm env t =
let e' = execExpr env e
throwSTM =<< snapshotExecValue e'

Catch t1 t2 -> execTerm env t1 `catch` \(_ :: ImmValue) -> execTerm env t2
coot marked this conversation as resolved.
Show resolved Hide resolved

Retry -> retry

ReadTVar n -> do
Expand Down Expand Up @@ -491,7 +518,7 @@ snapshotExecValue (ExecValInt x) = return (ImmValInt x)
snapshotExecValue (ExecValVar v _) = fmap ImmValVar
(snapshotExecValue =<< readTVar v)

execAtomically :: forall m t. (MonadSTM m, MonadThrow (STM m), MonadCatch m)
execAtomically :: forall m t. (MonadSTM m, MonadCatch (STM m), MonadCatch m)
=> Term t -> m TxResult
execAtomically t =
toTxResult <$> try (atomically action')
Expand Down Expand Up @@ -657,7 +684,7 @@ genTerm env tyrep =
Nothing)
]

binTerm = frequency [ (2, bindTerm), (1, orElseTerm)]
binTerm = frequency [ (2, bindTerm), (1, orElseTerm), (1, catchTerm)]

bindTerm =
sized $ \sz -> do
Expand All @@ -671,10 +698,15 @@ genTerm env tyrep =
return (Bind t1 name t2)

orElseTerm =
sized $ \sz -> resize (sz `div` 2) $
scale (`div` 2) $
OrElse <$> genTerm env tyrep
<*> genTerm env tyrep

catchTerm =
scale (`div` 2) $
Catch <$> genTerm env tyrep
<*> genTerm env tyrep

genSomeExpr :: GenEnv -> Gen SomeExpr
genSomeExpr env =
oneof'
Expand Down Expand Up @@ -713,6 +745,8 @@ shrinkTerm t =
case t of
Return e -> [Return e' | e' <- shrinkExpr e]
Throw e -> [Throw e' | e' <- shrinkExpr e]
Catch t1 t2 -> [t1, t2]
coot marked this conversation as resolved.
Show resolved Hide resolved
++ [Catch t1' t2' | (t1', t2') <- liftShrink2 shrinkTerm shrinkTerm (t1, t2)]
Retry -> []
ReadTVar _ -> []

Expand All @@ -721,12 +755,10 @@ shrinkTerm t =
NewTVar e -> [NewTVar e' | e' <- shrinkExpr e]

Bind t1 n t2 -> [ t2 | nameId n `Set.notMember` freeNamesTerm t2 ]
++ [ Bind t1' n t2 | t1' <- shrinkTerm t1 ]
++ [ Bind t1 n t2' | t2' <- shrinkTerm t2 ]
++ [ Bind t1' n t2' | (t1', t2') <- liftShrink2 shrinkTerm shrinkTerm (t1, t2) ]

OrElse t1 t2 -> [t1, t2]
++ [ OrElse t1' t2 | t1' <- shrinkTerm t1 ]
++ [ OrElse t1 t2' | t2' <- shrinkTerm t2 ]
++ [ OrElse t1' t2' | (t1', t2') <- liftShrink2 shrinkTerm shrinkTerm (t1, t2) ]

shrinkExpr :: Expr t -> [Expr t]
shrinkExpr ExprUnit = []
Expand All @@ -738,6 +770,10 @@ shrinkExpr (ExprName (Name _ (TyRepVar _))) = []
freeNamesTerm :: Term t -> Set NameId
freeNamesTerm (Return e) = freeNamesExpr e
freeNamesTerm (Throw e) = freeNamesExpr e
-- The current generator of catch term ignores the argument passed to the
-- handler.
-- TODO: Correctly handle free names when the handler also binds a variable.
freeNamesTerm (Catch t1 t2) = freeNamesTerm t1 <> freeNamesTerm t2
nfrisby marked this conversation as resolved.
Show resolved Hide resolved
freeNamesTerm Retry = Set.empty
freeNamesTerm (ReadTVar n) = Set.singleton (nameId n)
freeNamesTerm (WriteTVar n e) = Set.singleton (nameId n) <> freeNamesExpr e
Expand Down Expand Up @@ -768,6 +804,7 @@ prop_genSomeTerm (SomeTerm tyrep term) =
termSize :: Term a -> Int
termSize Return{} = 1
termSize Throw{} = 1
termSize (Catch a b) = 1 + termSize a + termSize b
termSize Retry{} = 1
termSize ReadTVar{} = 1
termSize WriteTVar{} = 1
Expand All @@ -778,6 +815,7 @@ termSize (OrElse a b) = 1 + termSize a + termSize b
termDepth :: Term a -> Int
termDepth Return{} = 1
termDepth Throw{} = 1
termDepth (Catch a b) = 1 + max (termDepth a) (termDepth b)
termDepth Retry{} = 1
termDepth ReadTVar{} = 1
termDepth WriteTVar{} = 1
Expand All @@ -790,6 +828,9 @@ showTerm p (Return e) = showParen (p > 10) $
showString "return " . showExpr 11 e
showTerm p (Throw e) = showParen (p > 10) $
showString "throwSTM " . showExpr 11 e
showTerm p (Catch t1 t2) = showParen (p > 9) $
showTerm 10 t1 . showString " `catch` "
. showTerm 10 t2
showTerm _ Retry = showString "retry"
showTerm p (ReadTVar n) = showParen (p > 10) $
showString "readTVar " . showName n
Expand Down