Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add exception to STM Expr for testing #35

Draft
wants to merge 13 commits into
base: main
Choose a base branch
from
28 changes: 24 additions & 4 deletions io-sim/src/Control/Monad/IOSim/Internal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -926,13 +926,33 @@ execAtomically !time !tid !tlbl !nextVid0 action0 k0 =

ThrowStm e ->
{-# SCC "execAtomically.go.ThrowStm" #-} do
-- Revert all the TVar writes
-- Rollback `TVar`s written since catch handler was installed
!_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written
k0 $ StmTxAborted [] (toException e)
case ctl of
AtomicallyFrame -> do
k0 $ StmTxAborted (Map.elems read) (toException e)

BranchFrame (CatchStmA h) k writtenOuter writtenOuterSeq createdOuterSeq ctl' ->
{-# SCC "execAtomically.go.BranchFrame" #-} do
-- Execute the left side in a new frame with an empty written set
let ctl'' = BranchFrame NoOpStmA k writtenOuter writtenOuterSeq createdOuterSeq ctl'
go ctl'' read Map.empty [] [] nextVid (h e)
--
BranchFrame _ _k writtenOuter writtenOuterSeq createdOuterSeq ctl' ->
{-# SCC "execAtomically.go.BranchFrame" #-} do
go ctl' read writtenOuter writtenOuterSeq createdOuterSeq nextVid (ThrowStm e)

CatchStm a h k ->
{-# SCC "execAtomically.go.ThrowStm" #-} do
-- Execute the catch handler with an empty written set.
-- but preserve ones that were set prior to it, as specified in the
-- [stm](https://hackage.haskell.org/package/stm/docs/Control-Monad-STM.html#v:catchSTM) package.
let ctl' = BranchFrame (CatchStmA h) k written writtenSeq createdSeq ctl
go ctl' read Map.empty [] [] nextVid a


Retry ->
{-# SCC "execAtomically.go.Retry" #-}
do
{-# SCC "execAtomically.go.Retry" #-} do
-- Always revert all the TVar writes for the retry
!_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written
case ctl of
Expand Down
32 changes: 29 additions & 3 deletions io-sim/src/Control/Monad/IOSim/Types.hs
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ runSTM (STM k) = k ReturnStm
data StmA s a where
ReturnStm :: a -> StmA s a
ThrowStm :: SomeException -> StmA s a
CatchStm :: StmA s a -> (SomeException -> StmA s a) -> (a -> StmA s b) -> StmA s b

NewTVar :: Maybe String -> x -> (TVar s x -> StmA s b) -> StmA s b
LabelTVar :: String -> TVar s a -> StmA s b -> StmA s b
Expand Down Expand Up @@ -322,6 +323,29 @@ instance MonadThrow (STM s) where
instance Exceptions.MonadThrow (STM s) where
throwM = MonadThrow.throwIO


instance MonadCatch (STM s) where

catch action handler = STM $ oneShot $ \k -> CatchStm (runSTM action) (runSTM . fromHandler handler) k
where
-- Get a total handler from the given handler
fromHandler :: Exception e => (e -> STM s a) -> SomeException -> STM s a
fromHandler h e = case fromException e of
Nothing -> throwIO e -- Rethrow the exception if handler does not handle it.
Just e' -> h e'

-- No need to consider masking for STM
generalBracket acquire release use = do
resource <- acquire
b <- use resource `catch` \e -> do
_ <- release resource (ExitCaseException e)
throwIO e
c <- release resource (ExitCaseSuccess b)
return (b, c)

instance Exceptions.MonadCatch (STM s) where
catch = MonadThrow.catch

instance MonadCatch (IOSim s) where
catch action handler =
IOSim $ oneShot $ \k -> Catch (runIOSim action) (runIOSim . handler) k
Expand Down Expand Up @@ -857,9 +881,11 @@ data StmTxResult s a =
| StmTxAborted [SomeTVar s] SomeException


-- | OrElse/Catch give rise to an alternate right hand side branch. A right branch
-- can be a NoOp
data BranchStmA s a = OrElseStmA (StmA s a) | NoOpStmA
-- | OrElse/Catch give rise to an alternate branch.
-- A branch of a branch is an empty one.
data BranchStmA s a = OrElseStmA (StmA s a)
| CatchStmA (SomeException -> StmA s a)
| NoOpStmA

data StmStack s b a where
-- | Executing in the context of a top level 'atomically'.
Expand Down
63 changes: 39 additions & 24 deletions io-sim/src/Control/Monad/IOSimPOR/Internal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -1174,32 +1174,47 @@ execAtomically time tid tlbl nextVid0 action0 k0 =
{-# SCC "execAtomically.go.ThrowStm" #-} do
-- Revert all the TVar writes
!_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written
k0 $ StmTxAborted (Map.elems read) (toException e)
case ctl of
AtomicallyFrame -> do
k0 $ StmTxAborted (Map.elems read) (toException e)

BranchFrame (CatchStmA h) k writtenOuter writtenOuterSeq createdOuterSeq ctl' ->
{-# SCC "execAtomically.go.BranchFrame" #-} do
let ctl'' = BranchFrame NoOpStmA k writtenOuter writtenOuterSeq createdOuterSeq ctl'
go ctl'' read Map.empty [] [] nextVid (h e)
--
BranchFrame _ _k writtenOuter writtenOuterSeq createdOuterSeq ctl' ->
{-# SCC "execAtomically.go.BranchFrame" #-} do
go ctl' read writtenOuter writtenOuterSeq createdOuterSeq nextVid (ThrowStm e)

CatchStm a h k ->
{-# SCC "execAtomically.go.ThrowStm" #-} do
-- Execute the left side in a new frame with an empty written set
let ctl' = BranchFrame (CatchStmA h) k written writtenSeq createdSeq ctl
go ctl' read Map.empty [] [] nextVid a

Retry ->
{-# SCC "execAtomically.go.Retry" #-}
do
-- Always revert all the TVar writes for the retry
!_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written
case ctl of
AtomicallyFrame -> do
-- Return vars read, so the thread can block on them
k0 $! StmTxBlocked $! Map.elems read

BranchFrame (OrElseStmA b) k writtenOuter writtenOuterSeq createdOuterSeq ctl' ->
{-# SCC "execAtomically.go.BranchFrame.OrElseStmA" #-} do
!_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written
-- Execute the orElse right hand with an empty written set
let ctl'' = BranchFrame NoOpStmA k writtenOuter writtenOuterSeq createdOuterSeq ctl'
go ctl'' read Map.empty [] [] nextVid b

BranchFrame _ _k writtenOuter writtenOuterSeq createdOuterSeq ctl' ->
{-# SCC "execAtomically.go.BranchFrame" #-} do
-- Retry makes sense only within a OrElse context. If it is a branch other than
-- OrElse left side, then bubble up the `retry` to the frame above.
-- Skip the continuation and propagate the retry into the outer frame
-- using the written set for the outer frame
go ctl' read writtenOuter writtenOuterSeq createdOuterSeq nextVid Retry
{-# SCC "execAtomically.go.Retry" #-} do
-- Always revert all the TVar writes for the retry
!_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written
case ctl of
AtomicallyFrame -> do
-- Return vars read, so the thread can block on them
k0 $! StmTxBlocked $! Map.elems read

BranchFrame (OrElseStmA b) k writtenOuter writtenOuterSeq createdOuterSeq ctl' ->
{-# SCC "execAtomically.go.BranchFrame.OrElseStmA" #-} do
-- Execute the orElse right hand with an empty written set
let ctl'' = BranchFrame NoOpStmA k writtenOuter writtenOuterSeq createdOuterSeq ctl'
go ctl'' read Map.empty [] [] nextVid b

BranchFrame _ _k writtenOuter writtenOuterSeq createdOuterSeq ctl' ->
{-# SCC "execAtomically.go.BranchFrame" #-} do
-- Retry makes sense only within a OrElse context. If it is a branch other than
-- OrElse left side, then bubble up the `retry` to the frame above.
-- Skip the continuation and propagate the retry into the outer frame
-- using the written set for the outer frame
go ctl' read writtenOuter writtenOuterSeq createdOuterSeq nextVid Retry

OrElse a b k ->
{-# SCC "execAtomically.go.OrElse" #-} do
Expand Down
2 changes: 1 addition & 1 deletion io-sim/test/Test/IOSim.hs
Original file line number Diff line number Diff line change
Expand Up @@ -1221,7 +1221,7 @@ prop_stm_referenceSim t =
-- | Compare the behaviour of the STM reference operational semantics with
-- the behaviour of any 'MonadSTM' STM implementation.
--
prop_stm_referenceM :: (MonadSTM m, MonadThrow (STM m), MonadCatch m)
prop_stm_referenceM :: (MonadSTM m, MonadCatch (STM m), MonadCatch m)
=> SomeTerm -> m Property
prop_stm_referenceM (SomeTerm _tyrep t) = do
let (r1, _heap) = evalAtomically t
Expand Down
67 changes: 57 additions & 10 deletions io-sim/test/Test/STM.hs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ module Test.STM where

import Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map
import Data.Maybe (fromMaybe, maybeToList)
import Data.Maybe
import Data.Set (Set)
import qualified Data.Set as Set
import Data.Type.Equality
Expand Down Expand Up @@ -68,6 +68,7 @@ data Term (t :: Type) where

Return :: Expr t -> Term t
Throw :: Expr a -> Term t
Catch :: Term t -> SomeException -> Term t -> Term t
Retry :: Term t

ReadTVar :: Name (TyVar t) -> Term t
Expand Down Expand Up @@ -267,6 +268,7 @@ data ImmValue where
ImmValVar :: ImmValue -> ImmValue
deriving (Eq, Show)


-- | In the execution in real STM transactions are aborted by throwing an
-- exception.
--
Expand Down Expand Up @@ -297,7 +299,7 @@ deriving instance Show (NfTerm t)
-- | The STM transition rules. They reduce a 'Term' to a normal-form 'NfTerm'.
--
-- Compare the implementation of this against the operational semantics in
-- Figure 4 in the paper. Note that @catch@ is not included.
-- Figure 4 in the paper including the `Catch` semantics from the Appendix A.
--
evalTerm :: Env -> Heap -> Allocs -> Term t -> (NfTerm t, Heap, Allocs)
evalTerm !env !heap !allocs term = case term of
Expand All @@ -310,6 +312,37 @@ evalTerm !env !heap !allocs term = case term of
where
e' = evalExpr env e

-- Exception semantics are detailed in "Appendix A Exception semantics" p 12-13 of
-- <https://research.microsoft.com/en-us/um/people/simonpj/papers/stm/stm.pdf>
Catch t1 exc t2 ->
let (nf1, heap', allocs') = evalTerm env heap mempty t1 in case nf1 of

-- Rule XSTM1
-- M; heap, {} => return P; heap', allocs'
-- --------------------------------------------------------
-- S[catch M N]; heap, allocs => S[return P]; heap', allocs'
NfReturn v -> (NfReturn v, heap', allocs <> allocs')

-- Rule XSTM2
-- M; heap, {} => throw P; heap', allocs'
-- --------------------------------------------------------
-- S[catch M N]; heap, allocs => S[N P]; heap U allocs', allocs U allocs'
NfThrow v ->
-- v should be compared to exception
case fromException exc of
-- TODO: Add eqValue for value
Just (ImmValInt 0) ->
evalTerm env (heap <> allocs') (allocs <> allocs') t2
-- Exception is not handled, bubble it up
_otherwise -> (NfThrow v, heap, allocs)

-- Rule XSTM3
-- M; heap, {} => retry; heap', allocs'
-- --------------------------------------------------------
-- S[catch M N]; heap, allocs => S[retry]; heap, allocs
NfRetry -> (NfRetry, heap, allocs)


Retry -> (NfRetry, heap, allocs)

-- Rule READ
Expand Down Expand Up @@ -438,7 +471,7 @@ extendExecEnv (Name name _tyrep) v (ExecEnv env) =

-- | Execute an STM 'Term' in the 'STM' monad.
--
execTerm :: (MonadSTM m, MonadThrow (STM m))
execTerm :: (MonadSTM m, MonadCatch (STM m))
=> ExecEnv m
-> Term t
-> STM m (ExecValue m t)
Expand All @@ -452,6 +485,8 @@ execTerm env t =
let e' = execExpr env e
throwSTM =<< snapshotExecValue e'

Catch t1 _ t2 -> execTerm env t1 `catch` \(_ :: ImmValue) -> execTerm env t2

Retry -> retry

ReadTVar n -> do
Expand Down Expand Up @@ -492,7 +527,7 @@ snapshotExecValue (ExecValInt x) = return (ImmValInt x)
snapshotExecValue (ExecValVar v _) = fmap ImmValVar
(snapshotExecValue =<< readTVar v)

execAtomically :: forall m t. (MonadSTM m, MonadThrow (STM m), MonadCatch m)
execAtomically :: forall m t. (MonadSTM m, MonadCatch (STM m), MonadCatch m)
=> Term t -> m TxResult
execAtomically t =
toTxResult <$> try (atomically action')
Expand Down Expand Up @@ -658,7 +693,7 @@ genTerm env tyrep =
Nothing)
]

binTerm = frequency [ (2, bindTerm), (1, orElseTerm)]
binTerm = frequency [ (2, bindTerm), (1, orElseTerm), (1, catchTerm)]

bindTerm =
sized $ \sz -> do
Expand All @@ -672,10 +707,16 @@ genTerm env tyrep =
return (Bind t1 name t2)

orElseTerm =
sized $ \sz -> resize (sz `div` 2) $
scale (`div` 2) $
OrElse <$> genTerm env tyrep
<*> genTerm env tyrep

catchTerm =
scale (`div` 2) $
Catch <$> genTerm env tyrep
<*> pure (toException $ ImmValInt 0) -- TODO: 0 is treated as an exception value, generalize it later
<*> genTerm env tyrep

genSomeExpr :: GenEnv -> Gen SomeExpr
genSomeExpr env =
oneof'
Expand Down Expand Up @@ -714,6 +755,8 @@ shrinkTerm t =
case t of
Return e -> [Return e' | e' <- shrinkExpr e]
Throw e -> [Throw e' | e' <- shrinkExpr e]
Catch t1 exc t2 -> [t1, t2]
++ [Catch t1' exc t2' | (t1', t2') <- liftShrink2 shrinkTerm shrinkTerm (t1, t2)]
Retry -> []
ReadTVar _ -> []

Expand All @@ -722,12 +765,10 @@ shrinkTerm t =
NewTVar e -> [NewTVar e' | e' <- shrinkExpr e]

Bind t1 n t2 -> [ t2 | nameId n `Set.notMember` freeNamesTerm t2 ]
++ [ Bind t1' n t2 | t1' <- shrinkTerm t1 ]
++ [ Bind t1 n t2' | t2' <- shrinkTerm t2 ]
++ [ Bind t1' n t2' | (t1', t2') <- liftShrink2 shrinkTerm shrinkTerm (t1, t2) ]

OrElse t1 t2 -> [t1, t2]
++ [ OrElse t1' t2 | t1' <- shrinkTerm t1 ]
++ [ OrElse t1 t2' | t2' <- shrinkTerm t2 ]
++ [ OrElse t1' t2' | (t1', t2') <- liftShrink2 shrinkTerm shrinkTerm (t1, t2) ]

shrinkExpr :: Expr t -> [Expr t]
shrinkExpr ExprUnit = []
Expand All @@ -739,6 +780,7 @@ shrinkExpr (ExprName (Name _ (TyRepVar _))) = []
freeNamesTerm :: Term t -> Set NameId
freeNamesTerm (Return e) = freeNamesExpr e
freeNamesTerm (Throw e) = freeNamesExpr e
freeNamesTerm (Catch t1 exc t2) = freeNamesTerm t1 <> freeNamesTerm t2
freeNamesTerm Retry = Set.empty
freeNamesTerm (ReadTVar n) = Set.singleton (nameId n)
freeNamesTerm (WriteTVar n e) = Set.singleton (nameId n) <> freeNamesExpr e
Expand Down Expand Up @@ -769,6 +811,7 @@ prop_genSomeTerm (SomeTerm tyrep term) =
termSize :: Term a -> Int
termSize Return{} = 1
termSize Throw{} = 1
termSize (Catch a _ b) = 1 + termSize a + termSize b
termSize Retry{} = 1
termSize ReadTVar{} = 1
termSize WriteTVar{} = 1
Expand All @@ -779,6 +822,7 @@ termSize (OrElse a b) = 1 + termSize a + termSize b
termDepth :: Term a -> Int
termDepth Return{} = 1
termDepth Throw{} = 1
termDepth (Catch a _ b) = 1 + max (termDepth a) (termDepth b)
termDepth Retry{} = 1
termDepth ReadTVar{} = 1
termDepth WriteTVar{} = 1
Expand All @@ -791,6 +835,9 @@ showTerm p (Return e) = showParen (p > 10) $
showString "return " . showExpr 11 e
showTerm p (Throw e) = showParen (p > 10) $
showString "throwSTM " . showExpr 11 e
showTerm p (Catch t1 _ t2) = showParen (p > 9) $
showTerm 10 t1 . showString " `catch` "
. showTerm 10 t2
showTerm _ Retry = showString "retry"
showTerm p (ReadTVar n) = showParen (p > 10) $
showString "readTVar " . showName n
Expand Down