diff --git a/src/Dex/Foreign/Context.hs b/src/Dex/Foreign/Context.hs index c793b464c..00f7d8deb 100644 --- a/src/Dex/Foreign/Context.hs +++ b/src/Dex/Foreign/Context.hs @@ -29,6 +29,7 @@ import TopLevel import Parser (parseExpr, exprAsModule) import Env hiding (Tag) import PPrint +import Err import Dex.Foreign.Util @@ -46,19 +47,19 @@ dexCreateContext = do let evalConfig = EvalConfig LLVM Nothing Nothing maybePreludeEnv <- evalPrelude evalConfig preludeSource case maybePreludeEnv of - Right preludeEnv -> toStablePtr $ Context evalConfig preludeEnv - Left err -> nullPtr <$ setError ("Failed to initialize standard library: " ++ pprint err) + Success preludeEnv -> toStablePtr $ Context evalConfig preludeEnv + Failure err -> nullPtr <$ setError ("Failed to initialize standard library: " ++ pprint err) where - evalPrelude :: EvalConfig -> String -> IO (Either Err TopStateEx) + evalPrelude :: EvalConfig -> String -> IO (Except TopStateEx) evalPrelude opts sourceText = do (results, env) <- runInterblockM opts initTopState $ map snd <$> evalSourceText sourceText return $ env `unlessError` results where unlessError :: TopStateEx -> [Result] -> Except TopStateEx - result `unlessError` [] = Right result - _ `unlessError` ((Result _ (Left err)):_) = Left err - result `unlessError` (_:t ) = result `unlessError` t + result `unlessError` [] = Success result + _ `unlessError` ((Result _ (Failure err)):_) = Failure err + result `unlessError` (_:t ) = result `unlessError` t dexDestroyContext :: Ptr Context -> IO () dexDestroyContext = freeStablePtr . castPtrToStablePtr . castPtr @@ -68,7 +69,7 @@ dexEval ctxPtr sourcePtr = do Context evalConfig env <- fromStablePtr ctxPtr source <- peekCString sourcePtr (results, finalEnv) <- runInterblockM evalConfig env $ evalSourceText source - let anyError = asum $ fmap (\case (_, Result _ (Left err)) -> Just err; _ -> Nothing) results + let anyError = asum $ fmap (\case (_, Result _ (Failure err)) -> Just err; _ -> Nothing) results case anyError of Nothing -> toStablePtr $ Context evalConfig finalEnv Just err -> setError (pprint err) $> nullPtr @@ -90,22 +91,23 @@ dexEvalExpr ctxPtr sourcePtr = do Context evalConfig env <- fromStablePtr ctxPtr source <- peekCString sourcePtr case parseExpr source of - Right expr -> do + Success expr -> do let (v, m) = exprAsModule expr let block = SourceBlock 0 0 LogNothing source (RunModule m) Nothing (Result [] maybeErr, newState) <- runInterblockM evalConfig env $ evalSourceBlock block case maybeErr of - Right () -> do - let Right (AtomBinderInfo _ (LetBound _ (Atom atom))) = lookupSourceName newState v + Success () -> do + let Success (AtomBinderInfo _ (LetBound _ (Atom atom))) = + lookupSourceName newState v toStablePtr atom - Left err -> setError (pprint err) $> nullPtr - Left err -> setError (pprint err) $> nullPtr + Failure err -> setError (pprint err) $> nullPtr + Failure err -> setError (pprint err) $> nullPtr dexLookup :: Ptr Context -> CString -> IO (Ptr Atom) dexLookup ctxPtr namePtr = do Context _ env <- fromStablePtr ctxPtr name <- peekCString namePtr - case lookupSourceName env $ fromString name of - Right (AtomBinderInfo _ (LetBound _ (Atom atom))) -> toStablePtr atom - Left _ -> setError "Unbound name" $> nullPtr - Right _ -> setError "Looking up an expression" $> nullPtr + case lookupSourceName env (fromString name) of + Success (AtomBinderInfo _ (LetBound _ (Atom atom))) -> toStablePtr atom + Failure _ -> setError "Unbound name" $> nullPtr + Success _ -> setError "Looking up an expression" $> nullPtr diff --git a/src/dex.hs b/src/dex.hs index 3e3527ed4..077830908 100644 --- a/src/dex.hs +++ b/src/dex.hs @@ -10,7 +10,7 @@ import System.Console.Haskeline import System.Exit import Control.Monad import Control.Monad.State.Strict -import Options.Applicative +import Options.Applicative hiding (Success, Failure) import Text.PrettyPrint.ANSI.Leijen (text, hardline) import System.Posix.Terminal (queryTerminal) import System.Posix.IO (stdOutput) @@ -26,6 +26,7 @@ import Resources import TopLevel import Parser hiding (Parser) import Env (envNames) +import Err import Export #ifdef DEX_LIVE import RenderHtml @@ -68,10 +69,10 @@ runMode evalMode preludeFile opts = do ExportMode dexPath objPath -> do results <- evalInterblockM opts env $ map snd <$> evalFile dexPath let outputs = foldMap (\(Result outs _) -> outs) results - let errors = foldMap (\case (Result _ (Left err)) -> [err]; _ -> []) results + let errors = foldMap (\case (Result _ (Failure err)) -> [err]; _ -> []) results putStr $ foldMap (nonEmptyNewline . pprint) errors let exportedFuns = foldMap (\case (ExportedFun name f) -> [(name, f)]; _ -> []) outputs - unless (backendName opts == LLVM) $ liftEitherIO $ + unless (backendName opts == LLVM) $ throw CompilerErr "Export only supported with the LLVM CPU backend" TopStateEx env' <- return env exportFunctions objPath exportedFuns $ topBindings $ topStateD env' @@ -95,7 +96,7 @@ replLoop prompt = do sourceBlock <- readMultiline prompt parseTopDeclRepl env <- lift getTopStateEx result <- lift $ evalSourceBlock sourceBlock - case result of Result _ (Left _) -> lift $ setTopStateEx env + case result of Result _ (Failure _) -> lift $ setTopStateEx env _ -> return () liftIO $ putStrLn $ pprint result @@ -112,8 +113,8 @@ dexCompletions (line, _) = do return (rest, completions) liftErrIO :: MonadIO m => Except a -> m a -liftErrIO (Left err) = liftIO $ putStrLn (pprint err) >> exitFailure -liftErrIO (Right x) = return x +liftErrIO (Failure err) = liftIO $ putStrLn (pprint err) >> exitFailure +liftErrIO (Success ans) = return ans readMultiline :: (MonadException m, MonadIO m) => String -> (String -> Maybe a) -> InputT m a diff --git a/src/lib/Builder.hs b/src/lib/Builder.hs index 973949992..56df2e4cf 100644 --- a/src/lib/Builder.hs +++ b/src/lib/Builder.hs @@ -64,7 +64,7 @@ import PPrint () import Util (bindM2, scanM, restructure) newtype BuilderT m a = BuilderT (ReaderT BuilderEnvR (CatT BuilderEnvC m) a) - deriving (Functor, Applicative, Monad, MonadIO, MonadFail, Alternative) + deriving (Functor, Applicative, Monad, MonadIO, MonadFail, Fallible, Alternative) type Builder = BuilderT Identity type BuilderEnv = (BuilderEnvR, BuilderEnvC) @@ -140,7 +140,7 @@ freshNestedBindersRec substEnv (Nest b bs) = do vs <- freshNestedBindersRec (substEnv <> b @> SubstVal (Var v)) bs return $ Nest v vs -buildPi :: (MonadError Err m, MonadBuilder m) +buildPi :: (Fallible m, MonadBuilder m) => Binder -> (Atom -> m (Arrow, Type)) -> m Atom buildPi b f = do scope <- getScope @@ -582,8 +582,8 @@ checkBuilder x = do let globals = freeVars x `envDiff` scope eff <- getAllowedEffects case checkType (scope <> globals) eff x of - Left e -> error $ pprint e - Right () -> return x + Failure e -> error $ pprint e + Success () -> return x isSingletonType :: Type -> Bool isSingletonType ty = case singletonTypeVal ty of @@ -676,16 +676,6 @@ instance (Monoid env, MonadCat env m) => MonadCat env (BuilderT m) where extend env' return (ans, scopeEnv) -instance MonadError e m => MonadError e (BuilderT m) where - throwError = lift . throwError - catchError m catch = do - envC <- builderLook - envR <- builderAsk - (ans, envC') <- lift $ runBuilderT' m (envR, envC) - `catchError` (\e -> runBuilderT' (catch e) (envR, envC)) - builderExtend envC' - return ans - instance MonadReader r m => MonadReader r (BuilderT m) where ask = lift ask local r m = do diff --git a/src/lib/Cat.hs b/src/lib/Cat.hs index 01aa6d062..849e7b647 100644 --- a/src/lib/Cat.hs +++ b/src/lib/Cat.hs @@ -24,8 +24,11 @@ import Control.Monad.Writer import Control.Monad.Identity import Control.Monad.Except hiding (Except) +import Err + newtype CatT env m a = CatT (StateT (env, env) m a) - deriving (Functor, Applicative, Monad, MonadTrans, MonadIO, MonadFail, Alternative) + deriving (Functor, Applicative, Monad, MonadTrans, MonadIO, MonadFail, Alternative, + Fallible) type Cat env = CatT env Identity @@ -75,14 +78,6 @@ instance MonadCat env m => MonadCat env (ExceptT e m) where Left err -> throwError err Right x -> return (x, env) -instance (Monoid env, MonadError e m) => MonadError e (CatT env m) where - throwError = lift . throwError - catchError m catch = do - env <- look - (ans, env') <- lift $ runCatT m env `catchError` (\e -> runCatT (catch e) env) - extend env' - return ans - instance (Monoid env, MonadReader r m) => MonadReader r (CatT env m) where ask = lift ask local f m = do diff --git a/src/lib/Err.hs b/src/lib/Err.hs index e04f4c20a..2044e8793 100644 --- a/src/lib/Err.hs +++ b/src/lib/Err.hs @@ -4,18 +4,26 @@ -- license that can be found in the LICENSE file or at -- https://developers.google.com/open-source/licenses/bsd +{-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE ConstraintKinds #-} -module Err (Err (..), ErrType (..), Except, SrcPos, SrcCtx, - throw, throwIf, modifyErr, MonadErr, - addContext, addSrcContext, catchIOExcept, liftEitherIO, - assertEq, ignoreExcept, pprint, docAsStr, asCompilerErr) where +module Err (Err (..), Errs (..), ErrType (..), Except (..), ErrCtx (..), + SrcPosCtx, SrcTextCtx, SrcPos, + Fallible (..), FallibleM (..), HardFailM (..), + runHardFail, throw, throwIf, + addContext, addSrcContext, addSrcTextContext, + catchIOExcept, liftExcept, + assertEq, ignoreExcept, pprint, docAsStr, asCompilerErr, + FallibleApplicativeWrapper, traverseMergingErrs) where import Control.Exception hiding (throw) +import Control.Applicative import Control.Monad -import Control.Monad.Except hiding (Except) +import Control.Monad.Identity +import Control.Monad.State.Strict +import Control.Monad.Reader import Data.Text (unpack) import Data.Text.Prettyprint.Doc.Render.Text import Data.Text.Prettyprint.Doc @@ -23,7 +31,11 @@ import GHC.Stack import System.Environment import System.IO.Unsafe -data Err = Err ErrType SrcCtx String deriving (Show, Eq) +-- === core API === + +data Err = Err ErrType ErrCtx String deriving (Show, Eq) +newtype Errs = Errs [Err] deriving (Show, Eq, Semigroup, Monoid) + data ErrType = NoErr | ParseErr | TypeErr @@ -42,61 +54,164 @@ data ErrType = NoErr | ZipErr | EscapedNameErr | ModuleImportErr + | MonadFailErr deriving (Show, Eq) -type Except = Either Err +type SrcPosCtx = Maybe SrcPos +type SrcTextCtx = Maybe (Int, String) -- Int is the offset in the source file +data ErrCtx = ErrCtx + { srcTextCtx :: SrcTextCtx + , srcPosCtx :: SrcPosCtx + , messageCtx :: [String] } + deriving (Show, Eq) type SrcPos = (Int, Int) -type SrcCtx = Maybe SrcPos -type MonadErr = MonadError Err +class MonadFail m => Fallible m where + throwErrs :: Errs -> m a + addErrCtx :: ErrCtx -> m a -> m a + +-- We have this in its own class because IO and `Except` can't implement it +-- (but FallibleM can) +class Fallible m => CtxReader m where + getErrCtx :: m ErrCtx + +-- We have this in its own class because StateT can't implement it +-- (but FallibleM, Except and IO all can) +class Fallible m => FallibleApplicative m where + mergeErrs :: m a -> m b -> m (a, b) + +newtype FallibleM a = + FallibleM { fromFallibleM :: ReaderT ErrCtx Except a } + deriving (Functor, Applicative, Monad) + +instance Fallible FallibleM where + throwErrs errs = FallibleM $ lift $ throwErrs errs + addErrCtx ctx (FallibleM m) = FallibleM $ local (<> ctx) m + +instance FallibleApplicative FallibleM where + mergeErrs (FallibleM (ReaderT f1)) (FallibleM (ReaderT f2)) = + FallibleM $ ReaderT \ctx -> mergeErrs (f1 ctx) (f2 ctx) + +instance CtxReader FallibleM where + getErrCtx = FallibleM ask + +instance Fallible IO where + throwErrs errs = throwIO errs + addErrCtx ctx m = do + result <- catchIOExcept m + liftExcept $ addErrCtx ctx result + +instance FallibleApplicative IO where + mergeErrs m1 m2 = do + result1 <- catchIOExcept m1 + result2 <- catchIOExcept m2 + liftExcept $ mergeErrs result1 result2 + +-- === Except type === + +-- Except is isomorphic to `Either Errs` but having a distinct type makes it +-- easier to debug type errors. + +data Except a = + Failure Errs + | Success a + deriving (Show, Eq) + +instance Functor Except where + fmap = liftM + +instance Applicative Except where + pure = return + liftA2 = liftM2 + +instance Monad Except where + return = Success + Failure errs >>= _ = Failure errs + Success x >>= f = f x + +-- === FallibleApplicativeWrapper === + +-- Wraps a Fallible monad, presenting an applicative interface that sequences +-- actions using the error-concatenating `mergeErrs` instead of the default +-- abort-on-failure sequencing. + +newtype FallibleApplicativeWrapper m a = + FallibleApplicativeWrapper { fromFallibleApplicativeWrapper :: m a } + deriving (Functor) + +instance FallibleApplicative m => Applicative (FallibleApplicativeWrapper m) where + pure x = FallibleApplicativeWrapper $ pure x + liftA2 f (FallibleApplicativeWrapper m1) (FallibleApplicativeWrapper m2) = + FallibleApplicativeWrapper $ fmap (uncurry f) (mergeErrs m1 m2) + +-- === HardFail === + +-- Implements Fallible by crashing. Used in type querying when we want to avoid +-- work by trusting decl annotations and skipping the checks. +newtype HardFailM a = + HardFailM { runHardFail' :: Identity a } + deriving (Functor, Applicative, Monad) + +runHardFail :: HardFailM a -> a +runHardFail m = runIdentity $ runHardFail' m + +instance MonadFail HardFailM where + fail s = error s + +instance Fallible HardFailM where + throwErrs errs = error $ pprint errs + addErrCtx _ cont = cont + +instance FallibleApplicative HardFailM where + mergeErrs cont1 cont2 = (,) <$> cont1 <*> cont2 + +-- === convenience layer === -throw :: MonadErr m => ErrType -> String -> m a -throw e s = throwError $ Err e Nothing s +throw :: Fallible m => ErrType -> String -> m a +throw errTy s = throwErrs $ Errs [Err errTy mempty s] -throwIf :: MonadErr m => Bool -> ErrType -> String -> m () +throwIf :: Fallible m => Bool -> ErrType -> String -> m () throwIf True e s = throw e s throwIf False _ _ = return () -modifyErr :: MonadError e m => m a -> (e -> e) -> m a -modifyErr m f = catchError m \e -> throwError (f e) +addContext :: Fallible m => String -> m a -> m a +addContext s m = addErrCtx (mempty {messageCtx = [s]}) m -asCompilerErr :: MonadErr m => m a -> m a -asCompilerErr m = - modifyErr m (\(Err _ c msg) -> Err CompilerErr c msg) +addSrcContext :: Fallible m => SrcPosCtx -> m a -> m a +addSrcContext ctx m = addErrCtx (mempty {srcPosCtx = ctx}) m -addContext :: MonadErr m => String -> m a -> m a -addContext s m = modifyErr m \(Err e p s') -> Err e p (s' ++ "\n" ++ s) +addSrcTextContext :: Fallible m => Int -> String -> m a -> m a +addSrcTextContext offset text m = + addErrCtx (mempty {srcTextCtx = Just (offset, text)}) m -addSrcContext :: MonadErr m => SrcCtx -> m a -> m a -addSrcContext ctx m = modifyErr m updateErr - where - updateErr :: Err -> Err - updateErr (Err e ctx' s) = case ctx' of Nothing -> Err e ctx s - Just _ -> Err e ctx' s - -catchIOExcept :: (MonadIO m , MonadErr m) => IO a -> m a -catchIOExcept m = (liftIO >=> liftEither) $ (liftM Right m) `catches` - [ Handler \(e::Err) -> return $ Left e - , Handler \(e::IOError) -> return $ Left $ Err DataIOErr Nothing $ show e - , Handler \(e::SomeException) -> return $ Left $ Err CompilerErr Nothing $ show e +catchIOExcept :: MonadIO m => IO a -> m (Except a) +catchIOExcept m = liftIO $ (liftM Success m) `catches` + [ Handler \(e::Errs) -> return $ Failure e + , Handler \(e::IOError) -> return $ Failure $ Errs [Err DataIOErr mempty $ show e] + , Handler \(e::SomeException) -> return $ Failure $ Errs [Err CompilerErr mempty $ show e] ] -liftEitherIO :: (Exception e, MonadIO m) => Either e a -> m a -liftEitherIO (Left err) = liftIO $ throwIO err -liftEitherIO (Right x ) = return x +liftExcept :: Fallible m => Except a -> m a +liftExcept (Failure errs) = throwErrs errs +liftExcept (Success ans) = return ans ignoreExcept :: HasCallStack => Except a -> a -ignoreExcept (Left e) = error $ pprint e -ignoreExcept (Right x) = x +ignoreExcept (Failure e) = error $ pprint e +ignoreExcept (Success x) = x -assertEq :: (HasCallStack, MonadErr m, Show a, Pretty a, Eq a) => a -> a -> String -> m () +assertEq :: (HasCallStack, Fallible m, Show a, Pretty a, Eq a) => a -> a -> String -> m () assertEq x y s = if x == y then return () else throw CompilerErr msg where msg = "assertion failure (" ++ s ++ "):\n" ++ pprint x ++ " != " ++ pprint y ++ "\n\n" ++ prettyCallStack callStack ++ "\n" +-- TODO: think about the best way to handle these. This is just a +-- backwards-compatibility shim. +asCompilerErr :: Fallible m => m a -> m a +asCompilerErr cont = addContext "(This is a compiler error!)" cont + -- === small pretty-printing utils === -- These are here instead of in PPrint.hs for import cycle reasons @@ -110,15 +225,53 @@ layout :: LayoutOptions layout = if unbounded then LayoutOptions Unbounded else defaultLayoutOptions where unbounded = unsafePerformIO $ (Just "1"==) <$> lookupEnv "DEX_PPRINT_UNBOUNDED" +traverseMergingErrs :: (Traversable f, FallibleApplicative m) + => (a -> m b) -> f a -> m (f b) +traverseMergingErrs f xs = + fromFallibleApplicativeWrapper $ traverse (\x -> FallibleApplicativeWrapper $ f x) xs + -- === instances === -instance MonadFail (Either Err) where - fail s = Left $ Err CompilerErr Nothing s +instance MonadFail FallibleM where + fail s = throw MonadFailErr s + +instance Fallible Except where + throwErrs errs = Failure errs + + addErrCtx _ (Success ans) = Success ans + addErrCtx ctx (Failure (Errs errs)) = + Failure $ Errs [Err errTy (ctx <> ctx') s | Err errTy ctx' s <- errs] + +instance FallibleApplicative Except where + mergeErrs (Success x) (Success y) = Success (x, y) + mergeErrs x y = Failure (getErrs x <> getErrs y) + where getErrs :: Except a -> Errs + getErrs = \case Failure e -> e + Success _ -> mempty + +instance MonadFail Except where + fail s = Failure $ Errs [Err CompilerErr mempty s] -instance Exception Err +instance Exception Errs instance Pretty Err where - pretty (Err e _ s) = pretty e <> pretty s + pretty (Err e ctx s) = pretty e <> pretty s <> prettyCtx + -- TODO: figure out a more uniform way to newlines + where prettyCtx = case ctx of + ErrCtx _ Nothing [] -> mempty + _ -> hardline <> pretty ctx + +instance Pretty ErrCtx where + pretty (ErrCtx maybeTextCtx maybePosCtx messages) = + -- The order of messages is outer-scope-to-inner-scope, but we want to print + -- them starting the other way around (Not for a good reason. It's just what + -- we've always done.) + prettyLines (reverse messages) <> highlightedSource + where + highlightedSource = case (maybeTextCtx, maybePosCtx) of + (Just (offset, text), Just (start, stop)) -> + hardline <> pretty (highlightRegion (start - offset, stop - offset) text) + _ -> mempty instance Pretty ErrType where pretty e = case e of @@ -144,3 +297,79 @@ instance Pretty ErrType where ZipErr -> "Zipping error" EscapedNameErr -> "Escaped name" ModuleImportErr -> "Module import error" + MonadFailErr -> "MonadFail error (internal error)" + +instance Fallible m => Fallible (ReaderT r m) where + throwErrs errs = lift $ throwErrs errs + addErrCtx ctx (ReaderT f) = ReaderT \r -> addErrCtx ctx $ f r + +instance FallibleApplicative m => FallibleApplicative (ReaderT r m) where + mergeErrs (ReaderT f1) (ReaderT f2) = + ReaderT \r -> mergeErrs (f1 r) (f2 r) + +instance CtxReader m => CtxReader (ReaderT r m) where + getErrCtx = lift getErrCtx + +instance Pretty Errs where + pretty (Errs [err]) = pretty err + pretty (Errs errs) = prettyLines errs + +instance Fallible m => Fallible (StateT s m) where + throwErrs errs = lift $ throwErrs errs + addErrCtx ctx (StateT f) = StateT \s -> addErrCtx ctx $ f s + +instance CtxReader m => CtxReader (StateT s m) where + getErrCtx = lift getErrCtx + +instance Semigroup ErrCtx where + ErrCtx text pos ctxStrs <> ErrCtx text' pos' ctxStrs' = + ErrCtx (leftmostJust text text') + (rightmostJust pos pos' ) + (ctxStrs <> ctxStrs') +instance Monoid ErrCtx where + mempty = ErrCtx Nothing Nothing [] + +-- === misc util stuff === + +leftmostJust :: Maybe a -> Maybe a -> Maybe a +leftmostJust (Just x) _ = Just x +leftmostJust Nothing y = y + +rightmostJust :: Maybe a -> Maybe a -> Maybe a +rightmostJust = flip leftmostJust + +prettyLines :: (Foldable f, Pretty a) => f a -> Doc ann +prettyLines xs = foldMap (\d -> pretty d <> hardline) xs + +highlightRegion :: (Int, Int) -> String -> String +highlightRegion pos@(low, high) s + | low > high || high > length s = error $ "Bad region: \n" + ++ show pos ++ "\n" ++ s + | otherwise = + -- TODO: flag to control line numbers + -- (disabling for now because it makes quine tests tricky) + -- "Line " ++ show (1 + lineNum) ++ "\n" + + allLines !! lineNum ++ "\n" + ++ take start (repeat ' ') ++ take (stop - start) (repeat '^') ++ "\n" + where + allLines = lines s + (lineNum, start, stop) = getPosTriple pos allLines + +getPosTriple :: (Int, Int) -> [String] -> (Int, Int, Int) +getPosTriple (start, stop) lines_ = (lineNum, start - offset, stop') + where + lineLengths = map ((+1) . length) lines_ + lineOffsets = cumsum lineLengths + lineNum = maxLT lineOffsets start + offset = lineOffsets !! lineNum + stop' = min (stop - offset) (lineLengths !! lineNum) + +cumsum :: [Int] -> [Int] +cumsum xs = scanl (+) 0 xs + +maxLT :: Ord a => [a] -> a -> Int +maxLT [] _ = 0 +maxLT (x:xs) n = if n < x then -1 + else 1 + maxLT xs n + diff --git a/src/lib/Export.hs b/src/lib/Export.hs index 35cf14da6..304b341ad 100644 --- a/src/lib/Export.hs +++ b/src/lib/Export.hs @@ -37,7 +37,7 @@ import Optimize exportFunctions :: FilePath -> [(String, Atom)] -> Bindings -> IO () exportFunctions objPath funcs env = do let names = fmap fst funcs - unless (length (nub names) == length names) $ liftEitherIO $ + unless (length (nub names) == length names) $ throw CompilerErr "Duplicate export names" modules <- forM funcs $ \(name, funcAtom) -> do let (impModule, _) = prepareFunctionForExport env name funcAtom diff --git a/src/lib/Imp.hs b/src/lib/Imp.hs index bf37743c9..55983b109 100644 --- a/src/lib/Imp.hs +++ b/src/lib/Imp.hs @@ -1290,7 +1290,7 @@ withDevice device = local (\opts -> opts {curDevice = device }) -- "shouldn't launch a kernel from device/thread code" -- State keeps track of _all_ names used in the program, Reader keeps the type env. -type ImpCheckM a = StateT (Env ()) (ReaderT (Env IType, Device) (Either Err)) a +type ImpCheckM a = StateT (Env ()) (ReaderT (Env IType, Device) Except) a instance Checkable ImpModule where -- TODO: check main function defined @@ -1479,14 +1479,14 @@ impInstrTypes instr = case instr of IQueryParallelism _ _ -> [IIdxRepTy, IIdxRepTy] ICall (_:>IFunType _ _ resultTys) _ -> resultTys -checkImpBinOp :: MonadError Err m => BinOp -> IType -> IType -> m IType +checkImpBinOp :: Fallible m => BinOp -> IType -> IType -> m IType checkImpBinOp op x y = do retTy <- checkBinOp op (BaseTy x) (BaseTy y) case retTy of BaseTy bt -> return bt _ -> throw CompilerErr $ "Unexpected BinOp return type: " ++ pprint retTy -checkImpUnOp :: MonadError Err m => UnOp -> IType -> m IType +checkImpUnOp :: Fallible m => UnOp -> IType -> m IType checkImpUnOp op x = do retTy <- checkUnOp op (BaseTy x) case retTy of diff --git a/src/lib/Inference.hs b/src/lib/Inference.hs index 2ea44318a..3f9bc8b14 100644 --- a/src/lib/Inference.hs +++ b/src/lib/Inference.hs @@ -13,7 +13,6 @@ module Inference (inferModule, synthModule) where import Control.Applicative import Control.Monad import Control.Monad.Reader -import Control.Monad.Except hiding (Except) import Data.Maybe (fromJust) import Data.Foldable (fold, toList) import Data.Functor @@ -33,12 +32,13 @@ import Type import PPrint import Cat import Util +import Err data UInferEnv = UInferEnv { inferSubst :: SubstEnv - , srcCtx :: SrcCtx + , srcCtx :: SrcPosCtx } -type UInferM = ReaderT UInferEnv (BuilderT (SolverT (Either Err))) +type UInferM = ReaderT UInferEnv (BuilderT (SolverT Except)) type SigmaType = Type -- may start with an implicit lambda type RhoType = Type -- doesn't start with an implicit lambda @@ -736,17 +736,17 @@ checkAllowedUnconditionally eff = do eff' <- zonk eff effAllowed <- getAllowedEffects >>= zonk return $ case checkExtends effAllowed eff' of - Left _ -> False - Right () -> True + Failure _ -> False + Success () -> True openEffectRow :: EffectRow -> UInferM EffectRow openEffectRow (EffectRow effs Nothing) = extendEffRow effs <$> freshEff openEffectRow effRow = return effRow -addSrcContext' :: SrcCtx -> UInferM a -> UInferM a +addSrcContext' :: SrcPosCtx -> UInferM a -> UInferM a addSrcContext' pos = addSrcContext pos . local (\e -> e { srcCtx = pos }) -getSrcCtx :: UInferM SrcCtx +getSrcCtx :: UInferM SrcPosCtx getSrcCtx = asks srcCtx getInferSubst :: UInferM SubstEnv @@ -759,7 +759,7 @@ extInferSubst ext = local (\e -> e { inferSubst = inferSubst e <> ext }) -- We have two variants here because at the top level we want error messages and -- internally we want to consider all alternatives. -type SynthPassM = SubstBuilderT (Either Err) +type SynthPassM = SubstBuilderT Except type SynthDictM = SubstBuilderT [] synthModule :: Scope -> SynthCandidates -> Module -> Except Module @@ -772,7 +772,7 @@ synthModule scope scs (Module Typed decls result) = do return $ Module Core decls' result' synthModule _ _ _ = error $ "Unexpected IR variant" -synthDictTop :: SrcCtx -> Type -> SynthPassM Atom +synthDictTop :: SrcPosCtx -> Type -> SynthPassM Atom synthDictTop ctx ty = do scope <- getScope scs <- getSynthCandidates @@ -784,7 +784,7 @@ synthDictTop ctx ty = do ++ "\n" ++ pprint solutions traverseHoles :: (MonadReader SubstEnv m, MonadBuilder m) - => (SrcCtx -> Type -> m Atom) -> TraversalDef m + => (SrcPosCtx -> Type -> m Atom) -> TraversalDef m traverseHoles fillHole = (traverseDecl recur, traverseExpr recur, synthPassAtom) where synthPassAtom atom = case atom of @@ -829,8 +829,8 @@ inferToSynth m = do scope <- getScope scs <- getSynthCandidates case runUInferM mempty scope scs m of - Left _ -> empty - Right (x, (_, decls)) -> do + Failure _ -> empty + Success (x, (_, decls)) -> do mapM_ emitDecl decls return x @@ -853,7 +853,7 @@ data SolverEnv = SolverEnv { solverVars :: Env Kind , solverSub :: Env (SubstVal Type) } type SolverT m = CatT SolverEnv m -runSolverT :: (MonadError Err m, HasVars a, Subst a, Pretty a) +runSolverT :: (Fallible m, HasVars a, Subst a, Pretty a) => CatT SolverEnv m a -> m a runSolverT m = liftM fst $ flip runCatT mempty $ do ans <- m >>= zonk @@ -900,21 +900,21 @@ checkLeaks tvs m = do unsolved :: SolverEnv -> Env Kind unsolved (SolverEnv vs sub) = vs `envDiff` sub -freshInferenceName :: (MonadError Err m, MonadCat SolverEnv m) => Kind -> m Name +freshInferenceName :: (Fallible m, MonadCat SolverEnv m) => Kind -> m Name freshInferenceName k = do env <- look let v = genFresh (rawName InferenceName "?") $ solverVars env extend $ SolverEnv (v@>k) mempty return v -freshType :: (MonadError Err m, MonadCat SolverEnv m) => Kind -> m Type +freshType :: (Fallible m, MonadCat SolverEnv m) => Kind -> m Type freshType EffKind = Eff <$> freshEff freshType k = Var . (:>k) <$> freshInferenceName k -freshEff :: (MonadError Err m, MonadCat SolverEnv m) => m EffectRow +freshEff :: (Fallible m, MonadCat SolverEnv m) => m EffectRow freshEff = EffectRow mempty . Just <$> freshInferenceName EffKind -constrainEq :: (MonadCat SolverEnv m, MonadError Err m) +constrainEq :: (MonadCat SolverEnv m, Fallible m) => Type -> Type -> m () constrainEq t1 t2 = do t1' <- zonk t1 @@ -931,7 +931,7 @@ zonk x = do s <- looks solverSub return $ scopelessSubst s x -unify :: (MonadCat SolverEnv m, MonadError Err m) +unify :: (MonadCat SolverEnv m, Fallible m) => Type -> Type -> m () unify t1 t2 = do t1' <- zonk t1 @@ -961,7 +961,7 @@ unify t1 t2 = do (Eff eff, Eff eff') -> unifyEff eff eff' _ -> throw TypeErr "" -unifyExtLabeledItems :: (MonadCat SolverEnv m, MonadError Err m) +unifyExtLabeledItems :: (MonadCat SolverEnv m, Fallible m) => ExtLabeledItems Type Name -> ExtLabeledItems Type Name -> m () unifyExtLabeledItems r1 r2 = do r1' <- zonk r1 @@ -987,8 +987,7 @@ unifyExtLabeledItems r1 r2 = do unifyExtLabeledItems (Ext NoLabeledItems t2) (Ext (LabeledItems extras1) (Just newTail)) -unifyEff :: (MonadCat SolverEnv m, MonadError Err m) - => EffectRow -> EffectRow -> m () +unifyEff :: (MonadCat SolverEnv m, Fallible m) => EffectRow -> EffectRow -> m () unifyEff r1 r2 = do r1' <- zonk r1 r2' <- zonk r2 @@ -1005,7 +1004,7 @@ unifyEff r1 r2 = do unifyEff (extendEffRow extras1 newRow) (EffectRow mempty t2) _ -> throw TypeErr "" -bindQ :: (MonadCat SolverEnv m, MonadError Err m) => Var -> Type -> m () +bindQ :: (MonadCat SolverEnv m, Fallible m) => Var -> Type -> m () bindQ v t | v `occursIn` t = throw TypeErr $ "Occurs check failure: " ++ pprint (v, t) | hasSkolems t = throw TypeErr "Can't unify with skolem vars" | otherwise = extend $ mempty { solverSub = v@>SubstVal t } diff --git a/src/lib/LLVMExec.hs b/src/lib/LLVMExec.hs index 5471b0e45..88384956f 100644 --- a/src/lib/LLVMExec.hs +++ b/src/lib/LLVMExec.hs @@ -97,7 +97,7 @@ compileAndBench shouldSyncCUDA logger ast fname args resultTypes = do let run = do let (CInt fd') = fdFD fd exitCode <- callFunPtr fPtr fd' argsPtr resultPtr - unless (exitCode == 0) $ throwIO $ Err RuntimeErr Nothing "" + unless (exitCode == 0) $ throw RuntimeErr "" freeLitVals resultPtr resultTypes let sync = when shouldSyncCUDA $ synchronizeCUDA exampleDuration <- snd <$> measureSeconds (run >> sync) @@ -127,7 +127,7 @@ checkedCallFunPtr fd argsPtr resultPtr fPtr = do (exitCode, duration) <- measureSeconds $ do exitCode <- callFunPtr fPtr fd' argsPtr resultPtr return exitCode - unless (exitCode == 0) $ throwIO $ Err RuntimeErr Nothing "" + unless (exitCode == 0) $ throw RuntimeErr "" return duration compileOneOff :: Logger [Output] -> L.Module -> String -> (DexExecutable -> IO a) -> IO a diff --git a/src/lib/Logging.hs b/src/lib/Logging.hs index de52f1b36..2eb589079 100644 --- a/src/lib/Logging.hs +++ b/src/lib/Logging.hs @@ -20,6 +20,7 @@ import Prelude hiding (log) import System.IO import PPrint +import Err data Logger l = Logger (MVar l) (Maybe Handle) @@ -47,7 +48,8 @@ readLog (Logger log _) = liftIO $ readMVar log -- === monadic interface === newtype LoggerT l m a = LoggerT (ReaderT (Logger l) m a) - deriving (Functor, Applicative, Monad, MonadTrans, MonadIO) + deriving (Functor, Applicative, Monad, MonadTrans, + MonadIO, MonadFail, Fallible) class (Pretty l, Monoid l, Monad m) => MonadLogger l m | m -> l where getLogger :: m (Logger l) @@ -62,4 +64,3 @@ logIO val = do runLoggerT :: Monoid l => Logger l -> LoggerT l m a -> m a runLoggerT l (LoggerT m) = runReaderT m l - diff --git a/src/lib/PPrint.hs b/src/lib/PPrint.hs index ae0e34ad3..f939740db 100644 --- a/src/lib/PPrint.hs +++ b/src/lib/PPrint.hs @@ -13,7 +13,7 @@ module PPrint (pprint, docAsStr, printLitBlock, PrecedenceLevel(..), DocPrec, PrettyPrec(..), atPrec, toJSONStr, prettyFromPrettyPrec, pAppArg, fromInfix) where -import Data.Aeson hiding (Result, Null, Value) +import Data.Aeson hiding (Result, Null, Value, Success) import GHC.Float import Data.Functor ((<&>)) import Data.Foldable (toList) @@ -529,8 +529,8 @@ instance Pretty SourceBlock where instance Pretty Result where pretty (Result outs r) = vcat (map pretty outs) <> maybeErr - where maybeErr = case r of Left err -> p err - Right () -> mempty + where maybeErr = case r of Failure err -> p err + Success () -> mempty instance Pretty Module where pretty = prettyFromPrettyPrec instance PrettyPrec Module where @@ -783,8 +783,8 @@ printOutput :: Bool -> Output -> String printOutput isatty out = addPrefix (addColor isatty Cyan ">") $ pprint $ out printResult :: Bool -> Except () -> String -printResult _ (Right ()) = "" -printResult isatty (Left err) = addColor isatty Red $ addPrefix ">" $ pprint err +printResult _ (Success ()) = "" +printResult isatty (Failure err) = addColor isatty Red $ addPrefix ">" $ pprint err addPrefix :: String -> String -> String addPrefix prefix str = unlines $ map prefixLine $ lines str @@ -805,8 +805,8 @@ instance ToJSON Result where toJSON (Result outs err) = object (outMaps <> errMaps) where errMaps = case err of - Left e -> ["error" .= String (fromString $ pprint e)] - Right () -> [] + Failure e -> ["error" .= String (fromString $ pprint e)] + Success () -> [] outMaps = flip foldMap outs $ \case BenchResult name compileTime runTime _ -> [ "bench_name" .= toJSON name diff --git a/src/lib/Parser.hs b/src/lib/Parser.hs index 06bb01219..bb71e20d6 100644 --- a/src/lib/Parser.hs +++ b/src/lib/Parser.hs @@ -4,7 +4,7 @@ -- license that can be found in the LICENSE file or at -- https://developers.google.com/open-source/licenses/bsd -module Parser (Parser, parseit, parseProg, runTheParser, parseData, +module Parser (Parser, parseit, parseProg, parseData, parseTopDeclRepl, uint, withSource, parseExpr, exprAsModule, emptyLines, brackets, symbol, symChar, keyWordStrs) where @@ -25,6 +25,7 @@ import qualified Text.Megaparsec.Debug import LabeledItems import Syntax import PPrint +import Err -- canPair is used for the ops (,) (|) (&) which should only appear inside -- parentheses (to avoid conflicts with records and other syntax) @@ -49,14 +50,14 @@ parseExpr :: String -> Except UExpr parseExpr s = parseit s (expr <* eof) parseit :: String -> Parser a -> Except a -parseit s p = case runTheParser s (p <* (optional eol >> eof)) of - Left e -> throw ParseErr (errorBundlePretty e) +parseit s p = case parse (runReaderT p (ParseCtx 0 False False)) "" s of + Left e -> throw ParseErr $ errorBundlePretty e Right x -> return x mustParseit :: String -> Parser a -> a mustParseit s p = case parseit s p of - Right ans -> ans - Left e -> error $ "This shouldn't happen:\n" ++ pprint e + Failure e -> error $ "This shouldn't happen:\n" ++ pprint e + Success x -> x importModule :: Parser SourceBlock' importModule = ImportModule <$> do @@ -1021,9 +1022,6 @@ symChars = ".,!$^&*:-~+/=<>|?\\@" -- === Util === -runTheParser :: String -> Parser a -> Either (ParseErrorBundle String Void) a -runTheParser s p = parse (runReaderT p (ParseCtx 0 False False)) "" s - sc :: Parser () sc = L.space space lineComment empty diff --git a/src/lib/RenderHtml.hs b/src/lib/RenderHtml.hs index 174e78059..53d4a2f61 100644 --- a/src/lib/RenderHtml.hs +++ b/src/lib/RenderHtml.hs @@ -25,6 +25,7 @@ import Syntax import PPrint import Parser import Serialize () +import Err pprintHtml :: ToMarkup a => a -> String pprintHtml x = renderHtml $ toMarkup x @@ -52,8 +53,8 @@ wrapBody blocks = docTypeHtml $ do instance ToMarkup Result where toMarkup (Result outs err) = foldMap toMarkup outs <> err' where err' = case err of - Left e -> cdiv "err-block" $ toHtml $ pprint e - Right () -> mempty + Failure e -> cdiv "err-block" $ toHtml $ pprint e + Success () -> mempty instance ToMarkup Output where toMarkup out = case out of @@ -75,10 +76,7 @@ cdiv c inner = H.div inner ! class_ (stringValue c) highlightSyntax :: String -> Html highlightSyntax s = foldMap (uncurry syntaxSpan) classified - where - classified = case runTheParser s (many (withSource classify) <* eof) of - Left e -> error $ errorBundlePretty e - Right ans -> ans + where classified = ignoreExcept $ parseit s (many (withSource classify) <* eof) syntaxSpan :: String -> StrClass -> Html syntaxSpan s NormalStr = toHtml s diff --git a/src/lib/SaferNames/Builder.hs b/src/lib/SaferNames/Builder.hs index 7e99535a8..163b106ab 100644 --- a/src/lib/SaferNames/Builder.hs +++ b/src/lib/SaferNames/Builder.hs @@ -350,7 +350,7 @@ buildPureNaryLam arr (EmptyAbs (Nest (b:>ty) rest)) cont = do cont (x':xs) buildPureNaryLam _ _ _ = error "impossible" -buildPi :: (MonadErr1 m, Builder m) +buildPi :: (Fallible1 m, Builder m) => Arrow -> Type n -> (forall l. Ext n l => AtomName l -> m l (EffectRow l, Type l)) -> m n (Type n) @@ -362,7 +362,7 @@ buildPi arr ty body = do Abs b (PairE effs resultTy) <- return ab return $ Pi $ PiType arr b effs resultTy -buildNonDepPi :: (MonadErr1 m, Builder m) +buildNonDepPi :: (Fallible1 m, Builder m) => Arrow -> Type n -> EffectRow n -> Type n -> m n (Type n) buildNonDepPi arr argTy effs resultTy = buildPi arr argTy \_ -> do resultTy' <- injectM resultTy diff --git a/src/lib/SaferNames/Inference.hs b/src/lib/SaferNames/Inference.hs index c77f3cc0c..9ce701c0a 100644 --- a/src/lib/SaferNames/Inference.hs +++ b/src/lib/SaferNames/Inference.hs @@ -12,7 +12,6 @@ import Prelude hiding ((.), id) import Control.Category import Control.Applicative import Control.Monad -import Control.Monad.Except hiding (Except) import Data.Foldable (toList) import Data.List (sortOn) import Data.String (fromString) @@ -55,7 +54,7 @@ isTopDecl decl = case decl of -- === Inferer monad === -class (MonadErr2 m, Builder2 m, EnvReader Name m) +class (MonadFail2 m, Fallible2 m, Builder2 m, EnvReader Name m) => Inferer (m::MonadKind2) data InfererM (i::S) (o::S) (a:: *) @@ -79,9 +78,9 @@ instance Monad (InfererM i o) where instance MonadFail (InfererM i o) where fail = undefined -instance MonadError Err (InfererM i o) where - throwError = undefined - catchError = undefined +instance Fallible (InfererM i o) where + throwErrs = undefined + addErrCtx = undefined instance Builder (InfererM i) where buildScoped _ = undefined @@ -128,10 +127,10 @@ typeReduceAtom = undefined tryGetType :: (Inferer m, HasType e) => e o -> m i o (Type o) tryGetType = undefined -getSrcCtx :: Inferer m => m i o SrcCtx +getSrcCtx :: Inferer m => m i o SrcPosCtx getSrcCtx = undefined -addSrcContext' :: Inferer m => SrcCtx -> m i o a -> m i o a +addSrcContext' :: Inferer m => SrcPosCtx -> m i o a -> m i o a addSrcContext' = undefined makeReqCon :: Inferer m => Type o -> m i o SuggestionStrength @@ -377,7 +376,7 @@ buildMonomorphicCase alts scrut resultTy = do resultTy' <- injectM resultTy buildNthOrderedAlt alts' scrutTy' resultTy' i vs -buildSortedCase :: (MonadErr1 m, Builder m, Emits n) +buildSortedCase :: (Fallible1 m, Builder m, Emits n) => Atom n -> [IndexedAlt n] -> Type n -> m n (Atom n) buildSortedCase scrut alts resultTy = do @@ -417,7 +416,7 @@ buildSortedCase scrut alts resultTy = do -- Make sure all of the alternatives are exclusive with the tail pattern (could -- technically allow overlap but this is simpler). Split based on the tail -- pattern's skipped types. -checkNoTailOverlaps :: MonadErr1 m => [IndexedAlt n] -> LabeledItems (Type n) -> m n () +checkNoTailOverlaps :: Fallible1 m => [IndexedAlt n] -> LabeledItems (Type n) -> m n () checkNoTailOverlaps alts (LabeledItems tys) = do forM_ alts \(IndexedAlt (VariantAlt label i) _) -> case M.lookup label tys of @@ -878,8 +877,8 @@ checkAllowedUnconditionally eff = do eff' <- zonk eff effAllowed <- getAllowedEffects >>= zonk return $ case checkExtends effAllowed eff' of - Left _ -> False - Right () -> True + Failure _ -> False + Success () -> True openEffectRow :: Inferer m => EffectRow o -> m i o (EffectRow o) openEffectRow (EffectRow effs Nothing) = extendEffRow effs <$> freshEff diff --git a/src/lib/SaferNames/Name.hs b/src/lib/SaferNames/Name.hs index fdcef520a..8970b3d40 100644 --- a/src/lib/SaferNames/Name.hs +++ b/src/lib/SaferNames/Name.hs @@ -35,7 +35,7 @@ module SaferNames.Name ( runScopeReaderT, runEnvReaderT, ScopeReaderT (..), EnvReaderT (..), lookupEnvM, dropSubst, extendEnv, MonadKind, MonadKind1, MonadKind2, - Monad1, Monad2, MonadErr1, MonadErr2, MonadFail1, MonadFail2, + Monad1, Monad2, Fallible1, Fallible2, MonadFail1, MonadFail2, ScopeReader2, ScopeExtender2, applyAbs, applyNaryAbs, ZipEnvReader (..), alphaEqTraversable, checkAlphaEq, AlphaEq, AlphaEqE (..), AlphaEqB (..), AlphaEqV, ConstE (..), @@ -522,8 +522,8 @@ type MonadKind2 = S -> S -> * -> * type Monad1 (m :: MonadKind1) = forall (n::S) . Monad (m n ) type Monad2 (m :: MonadKind2) = forall (n::S) (l::S) . Monad (m n l) -type MonadErr1 (m :: MonadKind1) = forall (n::S) . MonadErr (m n ) -type MonadErr2 (m :: MonadKind2) = forall (n::S) (l::S) . MonadErr (m n l) +type Fallible1 (m :: MonadKind1) = forall (n::S) . Fallible (m n ) +type Fallible2 (m :: MonadKind2) = forall (n::S) (l::S) . Fallible (m n l) type MonadFail1 (m :: MonadKind1) = forall (n::S) . MonadFail (m n ) type MonadFail2 (m :: MonadKind2) = forall (n::S) (l::S) . MonadFail (m n l) @@ -585,7 +585,7 @@ type AlphaEq e = AlphaEqE e :: Constraint -- TODO: consider generalizing this to something that can also handle e.g. -- unification and type checking with some light reduction class ( forall i1 i2 o. Monad (m i1 i2 o) - , forall i1 i2 o. MonadErr (m i1 i2 o) + , forall i1 i2 o. Fallible (m i1 i2 o) , forall i1 i2 o. MonadFail (m i1 i2 o) , forall i1 i2. ScopeGetter (m i1 i2) , forall i1 i2. ScopeExtender (m i1 i2)) @@ -620,11 +620,11 @@ class ( InjectableV v , forall c. NameColor c => AlphaEqE (v c)) => AlphaEqV (v::V) where -checkAlphaEq :: (AlphaEqE e, MonadErr1 m, ScopeReader m) +checkAlphaEq :: (AlphaEqE e, Fallible1 m, ScopeReader m) => e n -> e n -> m n () checkAlphaEq e1 e2 = do WithScope scope (PairE e1' e2') <- addScope $ PairE e1 e2 - liftEither $ + liftExcept $ runScopeReaderT scope $ flip runReaderT (emptyNameFunction, emptyNameFunction) $ runZipEnvReaderT $ withEmptyZipEnv $ alphaEqE e1' e2' @@ -687,7 +687,7 @@ instance (AlphaEqE e1, AlphaEqE e2) => AlphaEqE (EitherE e1 e2) where newtype ScopeReaderT (m::MonadKind) (n::S) (a:: *) = ScopeReaderT {runScopeReaderT' :: ReaderT (DistinctEvidence n, Scope n) m a} - deriving (Functor, Applicative, Monad, MonadError err, MonadFail) + deriving (Functor, Applicative, Monad, MonadFail, Fallible) runScopeReaderT :: Distinct n => Scope n -> ScopeReaderT m n a -> m a runScopeReaderT scope m = @@ -714,7 +714,7 @@ instance Monad m => ScopeExtender (ScopeReaderT m) where newtype EnvReaderT (v::V) (m::MonadKind1) (i::S) (o::S) (a:: *) = EnvReaderT { runEnvReaderT' :: ReaderT (NameFunction v i o) (m o) a } - deriving (Functor, Applicative, Monad, MonadError err, MonadFail) + deriving (Functor, Applicative, Monad, MonadFail, Fallible) type ScopedEnvReader (v::V) = EnvReaderT v (ScopeReaderT Identity) :: MonadKind2 @@ -757,7 +757,7 @@ class OutReader (e::E) (m::MonadKind1) | m -> e where newtype OutReaderT (e::E) (m::MonadKind1) (n::S) (a :: *) = OutReaderT { runOutReaderT' :: ReaderT (e n) (m n) a } - deriving (Functor, Applicative, Monad, MonadError err, MonadFail) + deriving (Functor, Applicative, Monad, MonadFail, Fallible) runOutReaderT :: e n -> OutReaderT e m n a -> m n a runOutReaderT env m = flip runReaderT env $ runOutReaderT' m @@ -792,7 +792,7 @@ instance OutReader e m => OutReader e (EnvReaderT v m i) where newtype ZipEnvReaderT (m::MonadKind1) (i1::S) (i2::S) (o::S) (a:: *) = ZipEnvReaderT { runZipEnvReaderT :: ReaderT (ZipEnv i1 i2 o) (m o) a } - deriving (Functor, Applicative, Monad, MonadError err, MonadFail) + deriving (Functor, Applicative, Monad, Fallible, MonadFail) type ZipEnv i1 i2 o = (NameFunction Name i1 o, NameFunction Name i2 o) @@ -813,7 +813,7 @@ instance (ScopeReader m, ScopeExtender m) env2' <- injectM env2 cont (env1', env2') -instance (Monad1 m, ScopeReader m, ScopeExtender m, MonadErr1 m, MonadFail1 m) +instance (Monad1 m, ScopeReader m, ScopeExtender m, Fallible1 m) => ZipEnvReader (ZipEnvReaderT m) where lookupZipEnvFst v = ZipEnvReaderT $ (!v) <$> fst <$> ask diff --git a/src/lib/SaferNames/Parser.hs b/src/lib/SaferNames/Parser.hs index af7d5787b..51dc321cc 100644 --- a/src/lib/SaferNames/Parser.hs +++ b/src/lib/SaferNames/Parser.hs @@ -4,7 +4,7 @@ -- license that can be found in the LICENSE file or at -- https://developers.google.com/open-source/licenses/bsd -module SaferNames.Parser (Parser, parseit, parseProg, runTheParser, parseData, +module SaferNames.Parser (Parser, parseit, parseProg, parseData, parseTopDeclRepl, uint, withSource, parseExpr, exprAsModule, emptyLines, brackets, symbol, symChar, keyWordStrs) where @@ -51,14 +51,14 @@ parseExpr :: String -> Except (UExpr VoidS) parseExpr s = parseit s (expr <* eof) parseit :: String -> Parser a -> Except a -parseit s p = case runTheParser s (p <* (optional eol >> eof)) of - Left e -> throw ParseErr (errorBundlePretty e) +parseit s p = case parse (runReaderT p (ParseCtx 0 False False)) "" s of + Left e -> throw ParseErr $ errorBundlePretty e Right x -> return x mustParseit :: String -> Parser a -> a mustParseit s p = case parseit s p of - Right ans -> ans - Left e -> error $ "This shouldn't happen:\n" ++ pprint e + Success x -> x + Failure e -> error $ "This shouldn't happen:\n" ++ pprint e importModule :: Parser SourceBlock' importModule = ImportModule <$> do @@ -1083,9 +1083,6 @@ symChars = ".,!$^&*:-~+/=<>|?\\@" -- === Util === -runTheParser :: String -> Parser a -> Either (ParseErrorBundle String Void) a -runTheParser s p = parse (runReaderT p (ParseCtx 0 False False)) "" s - sc :: Parser () sc = L.space space lineComment empty diff --git a/src/lib/SaferNames/SourceRename.hs b/src/lib/SaferNames/SourceRename.hs index ef2336d30..25e2bb8ee 100644 --- a/src/lib/SaferNames/SourceRename.hs +++ b/src/lib/SaferNames/SourceRename.hs @@ -23,7 +23,7 @@ import SaferNames.Name import SaferNames.ResolveImplicitNames import SaferNames.Syntax -renameSourceNames :: MonadErr m => Scope (n::S) -> SourceMap n -> SourceUModule -> m (UModule n) +renameSourceNames :: Fallible m => Scope (n::S) -> SourceMap n -> SourceUModule -> m (UModule n) -- renameSourceNames scope sourceMap m = -- runReaderT (runReaderT (renameSourceNames' m) (scope, sourceMap)) False renameSourceNames = undefined @@ -33,7 +33,7 @@ renameSourceNames = undefined -- We have this class because we want to read some extra context (whether -- shadowing is allowed) but we've already used up the MonadReader -- (we can't add a field because we want it to be monoidal). -class (Monad1 m, ScopeExtender m, MonadErr1 m) => Renamer m where +class (Monad1 m, ScopeExtender m, Fallible1 m) => Renamer m where askMayShadow :: m n Bool setMayShadow :: Bool -> m n a -> m n a askSourceMap :: m n (SourceMap n) @@ -52,11 +52,11 @@ class (Monad1 m, ScopeExtender m, MonadErr1 m) => Renamer m where -- instance ScopeExtender RenamerData where --- instance MonadError Err (RenamerData n) where +-- instance Fallibleor Err (RenamerData n) where -- instance Renamer RenamerData where --- instance MonadErr m => Renamer n (ReaderT (RenameEnv n) (ReaderT Bool m)) where +-- instance Fallible m => Renamer n (ReaderT (RenameEnv n) (ReaderT Bool m)) where -- askMayShadow = lift ask -- setMayShadow mayShadow cont = do -- env <- ask diff --git a/src/lib/SaferNames/Syntax.hs b/src/lib/SaferNames/Syntax.hs index 634565cfb..f381b21d1 100644 --- a/src/lib/SaferNames/Syntax.hs +++ b/src/lib/SaferNames/Syntax.hs @@ -20,7 +20,7 @@ module SaferNames.Syntax ( Type, Kind, BaseType (..), ScalarBaseType (..), Except, EffectP (..), Effect, UEffect, RWS (..), EffectRowP (..), EffectRow, UEffectRow, - SrcPos, Binder, Block (..), Decl (..), + Binder, Block (..), Decl (..), Expr (..), Atom (..), Arrow (..), PrimTC (..), Abs (..), PrimExpr (..), PrimCon (..), LitVal (..), PrimEffect (..), PrimOp (..), PrimHof (..), LamExpr (..), PiType (..), LetAnn (..), @@ -54,7 +54,7 @@ module SaferNames.Syntax ( considerNonDepPiType, fromNonDepTabTy, binderType, getProjection, applyIntBinOp, applyIntCmpOp, applyFloatBinOp, applyFloatUnOp, - SrcCtx, freshBinderNamePair, piArgType, piArrow, extendEffRow, + freshBinderNamePair, piArgType, piArrow, extendEffRow, pattern IdxRepTy, pattern IdxRepVal, pattern TagRepTy, pattern TagRepVal, pattern Word8Ty, pattern UnitTy, pattern PairTy, @@ -274,7 +274,7 @@ instance (InjectableE e, ScopeReader m, BindingsExtender m) newtype BindingsReaderT (m::MonadKind) (n::S) (a:: *) = BindingsReaderT {runBindingsReaderT' :: ReaderT (Bindings n) (ScopeReaderT m n) a } - deriving (Functor, Applicative, Monad, MonadError err, MonadFail) + deriving (Functor, Applicative, Monad, MonadFail, Fallible) runBindingsReaderT :: Distinct n => (Scope n, Bindings n) -> (BindingsReaderT m n a) -> m a runBindingsReaderT (scope, bindings) cont = @@ -512,14 +512,14 @@ data UPat' (n::S) (l::S) = | UPatTable (Nest UPat n l) deriving (Show) -data WithSrcE (a::E) (n::S) = WithSrcE SrcCtx (a n) +data WithSrcE (a::E) (n::S) = WithSrcE SrcPosCtx (a n) deriving (Show) -data WithSrcB (binder::B) (n::S) (l::S) = WithSrcB SrcCtx (binder n l) +data WithSrcB (binder::B) (n::S) (l::S) = WithSrcB SrcPosCtx (binder n l) deriving (Show) class HasSrcPos a where - srcPos :: a -> SrcCtx + srcPos :: a -> SrcPosCtx instance HasSrcPos (WithSrcE (a::E) (n::S)) where srcPos (WithSrcE pos _) = pos @@ -825,14 +825,14 @@ mkConsListTy = foldr PairTy UnitTy mkConsList :: [Atom n] -> Atom n mkConsList = foldr PairVal UnitVal -fromConsListTy :: MonadError Err m => Type n -> m [Type n] +fromConsListTy :: Fallible m => Type n -> m [Type n] fromConsListTy ty = case ty of UnitTy -> return [] PairTy t rest -> (t:) <$> fromConsListTy rest _ -> throw CompilerErr $ "Not a pair or unit: " ++ show ty -- ((...((ans & x{n}) & x{n-1})... & x2) & x1) -> (ans, [x1, ..., x{n}]) -fromLeftLeaningConsListTy :: MonadError Err m => Int -> Type n -> m (Type n, [Type n]) +fromLeftLeaningConsListTy :: Fallible m => Int -> Type n -> m (Type n, [Type n]) fromLeftLeaningConsListTy depth initTy = go depth initTy [] where go 0 ty xs = return (ty, reverse xs) @@ -840,7 +840,7 @@ fromLeftLeaningConsListTy depth initTy = go depth initTy [] PairTy lt rt -> go (remDepth - 1) lt (rt : xs) _ -> throw CompilerErr $ "Not a pair: " ++ show xs -fromConsList :: MonadError Err m => Atom n -> m [Atom n] +fromConsList :: Fallible m => Atom n -> m [Atom n] fromConsList xs = case xs of UnitVal -> return [] PairVal x rest -> (x:) <$> fromConsList rest diff --git a/src/lib/SaferNames/Type.hs b/src/lib/SaferNames/Type.hs index 22ab277a4..4d9c46144 100644 --- a/src/lib/SaferNames/Type.hs +++ b/src/lib/SaferNames/Type.hs @@ -21,14 +21,11 @@ import Prelude hiding (id) import Control.Category ((>>>)) import Control.Monad import Control.Monad.Reader -import Control.Monad.Except hiding (Except) -import Control.Monad.Identity import Data.Foldable (toList) import Data.Functor import qualified Data.List.NonEmpty as NE import qualified Data.Map.Strict as M import qualified Data.Set as S -import Data.Text.Prettyprint.Doc hiding (nest) import LabeledItems @@ -42,25 +39,25 @@ import SaferNames.PPrint () -- === top-level API === -checkModule :: (Distinct n, MonadErr m) => TopBindings n -> Module n -> m () +checkModule :: (Distinct n, Fallible m) => TopBindings n -> Module n -> m () checkModule env m = addContext ("Checking module:\n" ++ pprint m) $ asCompilerErr $ runBindingsReaderT (fromTopBindings env) $ checkTypes m -checkTypes :: (BindingsReader m, MonadErr1 m, CheckableE e) +checkTypes :: (BindingsReader m, Fallible1 m, CheckableE e) => e n -> m n () checkTypes e = do Distinct <- getDistinct WithBindings bindings scope e' <- addBindings e - liftEither $ runTyperT (scope, bindings) $ void $ checkE e' + liftExcept $ runTyperT (scope, bindings) $ void $ checkE e' getType :: (BindingsReader m, HasType e) => e n -> m n (Type n) getType e = do Distinct <- getDistinct WithBindings bindings scope e' <- addBindings e - injectM $ runIgnoreChecks $ runTyperT (scope, bindings) $ getTypeE e' + injectM $ runHardFail $ runTyperT (scope, bindings) $ getTypeE e' instantiatePi :: ScopeReader m => PiType n -> Atom n -> m n (EffectRow n, Atom n) instantiatePi (PiType _ b eff body) x = do @@ -70,43 +67,26 @@ instantiatePi (PiType _ b eff body) x = do -- === the type checking/querying monad === -- TODO: not clear why we need the explicit `Monad2` here since it should --- already be a superclass, transitively, through both MonadErr2 and +-- already be a superclass, transitively, through both Fallible2 and -- MonadAtomSubst. -class ( MonadFail2 m, Monad2 m, MonadErr2 m, EnvReader Name m +class ( Monad2 m, Fallible2 m, EnvReader Name m , BindingsGetter2 m, BindingsExtender2 m) => Typer (m::MonadKind2) where declareEffs :: EffectRow o -> m i o () extendAllowedEffect :: Effect o -> m i o () -> m i o () withAllowedEff :: EffectRow o -> m i o a -> m i o a --- This fakes MonadErr by just throwing a hard error using `error`. We use it --- to skip the checks (via laziness) when we just querying types. -newtype IgnoreChecks e a = IgnoreChecks { runIgnoreChecks' :: Identity a } - deriving (Functor, Applicative, Monad) - -runIgnoreChecks :: IgnoreChecks e a -> a -runIgnoreChecks = runIdentity . runIgnoreChecks' - -instance MonadFail (IgnoreChecks e) where - fail = error "Monad fail!" - -instance Pretty e => MonadError e (IgnoreChecks e) where - throwError e = error $ pprint e - catchError m _ = m - newtype TyperT (m::MonadKind) (i::S) (o::S) (a :: *) = TyperT { runTyperT' :: EnvReaderT Name (OutReaderT EffectRow (BindingsReaderT m)) i o a } - deriving ( Functor, Applicative, Monad, MonadFail + deriving ( Functor, Applicative, Monad , EnvReader Name , ScopeReader, ScopeGetter, BindingsReader , BindingsGetter, BindingsExtender) -deriving instance MonadError e m => MonadError e (TyperT m i o) - -runTyperT :: (MonadErr m, Distinct n) +runTyperT :: (Fallible m, Distinct n) => ScopedBindings n -> TyperT m n n a -> m a runTyperT scope m = do runBindingsReaderT scope $ @@ -114,7 +94,14 @@ runTyperT scope m = do runEnvReaderT idNameFunction $ runTyperT' m -instance (MonadFail m, MonadErr m) => Typer (TyperT m) where +instance Fallible m => MonadFail (TyperT m i o) where + fail = undefined + +instance Fallible m => Fallible (TyperT m i o) where + throwErrs = undefined + addErrCtx = undefined + +instance Fallible m => Typer (TyperT m) where declareEffs eff = TyperT do allowedEffs <- askOutReader checkExtends allowedEffs eff @@ -683,7 +670,7 @@ checkRWSAction rws f = do -- Having this as a separate helper function helps with "'b0' is untouchable" errors -- from GADT+monad type inference. -checkEmpty :: MonadErr m => Nest b n l -> m () +checkEmpty :: Fallible m => Nest b n l -> m () checkEmpty Empty = return () checkEmpty _ = throw TypeErr "Not empty" @@ -795,12 +782,12 @@ typeCheckRef x = do TC (RefType _ a) <- getTypeE x return a -checkArrowAndEffects :: MonadErr m => Arrow -> EffectRow n -> m () +checkArrowAndEffects :: Fallible m => Arrow -> EffectRow n -> m () checkArrowAndEffects PlainArrow _ = return () checkArrowAndEffects _ Pure = return () checkArrowAndEffects _ _ = throw TypeErr $ "Only plain arrows may have effects" -checkIntBaseType :: MonadError Err m => Bool -> BaseType -> m () +checkIntBaseType :: Fallible m => Bool -> BaseType -> m () checkIntBaseType allowVector t = case t of Scalar sbt -> checkSBT sbt Vector sbt | allowVector -> checkSBT sbt @@ -816,7 +803,7 @@ checkIntBaseType allowVector t = case t of notInt = throw TypeErr $ "Expected a fixed-width " ++ (if allowVector then "" else "scalar ") ++ "integer type, but found: " ++ pprint t -checkFloatBaseType :: MonadError Err m => Bool -> BaseType -> m () +checkFloatBaseType :: Fallible m => Bool -> BaseType -> m () checkFloatBaseType allowVector t = case t of Scalar sbt -> checkSBT sbt Vector sbt | allowVector -> checkSBT sbt @@ -829,7 +816,7 @@ checkFloatBaseType allowVector t = case t of notFloat = throw TypeErr $ "Expected a fixed-width " ++ (if allowVector then "" else "scalar ") ++ "floating-point type, but found: " ++ pprint t -checkValidCast :: MonadErr m => BaseType -> BaseType -> m () +checkValidCast :: Fallible m => BaseType -> BaseType -> m () checkValidCast (PtrType _) (PtrType _) = return () checkValidCast (PtrType _) (Scalar Int64Type) = return () checkValidCast (Scalar Int64Type) (PtrType _) = return () @@ -853,14 +840,14 @@ typeCheckBaseType e = data ArgumentType = SomeFloatArg | SomeIntArg | SomeUIntArg data ReturnType = SameReturn | Word8Return -checkOpArgType :: MonadError Err m => ArgumentType -> BaseType -> m () +checkOpArgType :: Fallible m => ArgumentType -> BaseType -> m () checkOpArgType argTy x = case argTy of SomeIntArg -> checkIntBaseType True x SomeUIntArg -> assertEq x (Scalar Word8Type) "" SomeFloatArg -> checkFloatBaseType True x -checkBinOp :: MonadError Err m => BinOp -> BaseType -> BaseType -> m BaseType +checkBinOp :: Fallible m => BinOp -> BaseType -> BaseType -> m BaseType checkBinOp op x y = do checkOpArgType argTy x assertEq x y "" @@ -884,7 +871,7 @@ checkBinOp op x y = do ia = SomeIntArg; fa = SomeFloatArg br = Word8Return; sr = SameReturn -checkUnOp :: MonadError Err m => UnOp -> BaseType -> m BaseType +checkUnOp :: Fallible m => UnOp -> BaseType -> m BaseType checkUnOp op x = do checkOpArgType argTy x return $ case retTy of @@ -949,7 +936,7 @@ checkEffRow effRow@(EffectRow effs effTail) = do declareEff :: Typer m => Effect o -> m i o () declareEff eff = declareEffs $ oneEffect eff -checkExtends :: MonadError Err m => EffectRow n -> EffectRow n -> m () +checkExtends :: Fallible m => EffectRow n -> EffectRow n -> m () checkExtends allowed (EffectRow effs effTail) = do let (EffectRow allowedEffs allowedEffTail) = allowed case effTail of diff --git a/src/lib/SourceRename.hs b/src/lib/SourceRename.hs index 365ffa5d4..bc7aaa95a 100644 --- a/src/lib/SourceRename.hs +++ b/src/lib/SourceRename.hs @@ -14,7 +14,6 @@ import Data.List (nub) import Data.String import Control.Monad.Writer import Control.Monad.Reader -import Control.Monad.Except hiding (Except) import qualified Data.Set as S import qualified Data.Map.Strict as M @@ -23,7 +22,7 @@ import Err import LabeledItems import Syntax -renameSourceNames :: MonadErr m => Scope -> SourceMap -> SourceUModule -> m UModule +renameSourceNames :: Fallible m => Scope -> SourceMap -> SourceUModule -> m UModule renameSourceNames scope sourceMap m = runReaderT (runReaderT (renameSourceNames' m) (scope, sourceMap)) False @@ -32,11 +31,11 @@ type RenameEnv = (Scope, SourceMap) -- We have this class because we want to read some extra context (whether -- shadowing is allowed) but we've already used up the MonadReader -- (we can't add a field because we want it to be monoidal). -class (MonadReader RenameEnv m, MonadErr m) => Renamer m where +class (MonadReader RenameEnv m, Fallible m) => Renamer m where askMayShadow :: m Bool setMayShadow :: Bool -> m a -> m a -instance MonadErr m => Renamer (ReaderT RenameEnv (ReaderT Bool m)) where +instance Fallible m => Renamer (ReaderT RenameEnv (ReaderT Bool m)) where askMayShadow = lift ask setMayShadow mayShadow cont = do env <- ask @@ -323,11 +322,15 @@ instance (Monoid env, MonadReader env m) => Monad (WithEnv env m) where instance Monoid env => MonadTrans (WithEnv env) where lift m = WithEnv $ fmap (,mempty) m -instance (Monoid env, MonadError e m, MonadReader env m) - => MonadError e (WithEnv env m) where - throwError e = lift $ throwError e - catchError (WithEnv m) handler = - WithEnv $ catchError m (runWithEnv . handler) +instance (Monoid env, MonadFail m, MonadReader env m) + => MonadFail (WithEnv env m) where + fail s = lift $ fail s + +instance (Monoid env, Fallible m, MonadReader env m) + => Fallible (WithEnv env m) where + throwErrs errs = lift $ throwErrs errs + addErrCtx ctx (WithEnv m) = WithEnv $ + addErrCtx ctx m -- === Traversal to find implicit names diff --git a/src/lib/Syntax.hs b/src/lib/Syntax.hs index 4a1af7b8b..f4bebba43 100644 --- a/src/lib/Syntax.hs +++ b/src/lib/Syntax.hs @@ -16,7 +16,7 @@ module Syntax ( Type, Kind, BaseType (..), ScalarBaseType (..), EffectP (..), Effect, RWS (..), EffectRowP (..), EffectRow, - ClassName (..), TyQual (..), SrcPos, Var, Binder, Block (..), Decl (..), + ClassName (..), TyQual (..), Var, Binder, Block (..), Decl (..), Expr (..), Atom (..), ArrowP (..), Arrow, PrimTC (..), Abs (..), PrimExpr (..), PrimCon (..), LitVal (..), PrimEffect (..), PrimOp (..), PrimHof (..), LamExpr, PiType, WithSrc (..), srcPos, LetAnn (..), @@ -33,9 +33,9 @@ module Syntax ( UAlt (..), AltP, Alt, ModuleName, IScope, BinderInfo (..), AnyBinderInfo (..), AsRecEnv (..), Bindings, CUDAKernel (..), BenchStats, - SrcCtx, Result (..), Output (..), OutFormat (..), - Err (..), ErrType (..), Except, throw, throwIf, modifyErr, addContext, - addSrcContext, catchIOExcept, liftEitherIO, (-->), (--@), (==>), + Result (..), Output (..), OutFormat (..), + Err (..), ErrType (..), Except, throw, throwIf, addContext, + addSrcContext, catchIOExcept, liftExcept, (-->), (--@), (==>), boundUVars, PassName (..), boundVars, renamingSubst, bindingsAsVars, freeVars, freeUVars, Subst, HasVars, BindsVars, Ptr, PtrType, AddressSpace (..), showPrimName, strToPrimName, primNameToStr, @@ -325,13 +325,13 @@ data UPat' = UPatBinder UBinder | UPatTable [UPat] deriving (Show) -data WithSrc a = WithSrc SrcCtx a +data WithSrc a = WithSrc SrcPosCtx a deriving (Show, Functor, Foldable, Traversable) pattern UPatIgnore :: UPat' pattern UPatIgnore = UPatBinder UIgnore -srcPos :: WithSrc a -> SrcCtx +srcPos :: WithSrc a -> SrcPosCtx srcPos (WithSrc pos _) = pos -- === primitive constructors and operators === @@ -361,7 +361,7 @@ data PrimCon e = Lit LitVal | ProdCon [e] | SumCon e Int e -- type, tag, payload - | ClassDictHole SrcCtx e -- Only used during type inference + | ClassDictHole SrcPosCtx e -- Only used during type inference | SumAsProd e e [[e]] -- type, tag, payload (only used during Imp lowering) -- These are just newtype wrappers. TODO: use ADTs instead | IntRangeVal e e e @@ -1662,7 +1662,7 @@ isTabTy (TabTy _ _) = True isTabTy _ = False -- ((...((ans & x{n}) & x{n-1})... & x2) & x1) -> (ans, [x1, ..., x{n}]) -fromLeftLeaningConsListTy :: MonadError Err m => Int -> Type -> m (Type, [Type]) +fromLeftLeaningConsListTy :: Fallible m => Int -> Type -> m (Type, [Type]) fromLeftLeaningConsListTy depth initTy = go depth initTy [] where go 0 ty xs = return (ty, reverse xs) diff --git a/src/lib/TopLevel.hs b/src/lib/TopLevel.hs index 79e8c89b3..efa0302d0 100644 --- a/src/lib/TopLevel.hs +++ b/src/lib/TopLevel.hs @@ -50,7 +50,7 @@ import Logging import LLVMExec import PPrint() import Parser -import Util (highlightRegion, measureSeconds, onFst, onSnd) +import Util (measureSeconds, onFst, onSnd) import Optimize import Parallelize @@ -102,7 +102,7 @@ execInterblockM opts env m = snd <$> runInterblockM opts env m -- === monad for wiring together the passes within each source block === -class ( forall n. MonadErr (m n) +class ( forall n. Fallible (m n) , forall n. MonadLogger [Output] (m n) , forall n. ConfigReader (m n) , forall n. MonadIO (m n) ) @@ -113,17 +113,18 @@ class ( forall n. MonadErr (m n) newtype PassesM (n::S.S) a = PassesM { runPassesM' :: ReaderT (Bool, EvalConfig, (S.DistinctWitness n, JointTopState n)) (LoggerT [Output] IO) a } - deriving (Functor, Applicative, Monad, MonadIO) + deriving (Functor, Applicative, Monad, MonadIO, MonadFail, Fallible) type ModulesImported = M.Map ModuleName ModuleImportStatus data ModuleImportStatus = CurrentlyImporting | FullyImported deriving Generic runPassesM :: S.Distinct n => Bool -> EvalConfig -> JointTopState n -> PassesM n a -> IO (Except a, [Output]) -runPassesM bench opts env (PassesM m) = do +runPassesM bench opts env m = do let maybeLogFile = logFile opts runLogger maybeLogFile \l -> - runExceptT $ catchIOExcept $ runLoggerT l $ runReaderT m $ (bench, opts, (S.Distinct, env)) + catchIOExcept $ runLoggerT l $ runReaderT (runPassesM' m) $ + (bench, opts, (S.Distinct, env)) -- ====== @@ -163,11 +164,11 @@ evalSourceBlock' block = case sbContents block of RunModule m -> do (maybeEvaluatedModule, outs) <- liftPassesM (requiresBench block) $ evalUModule m case maybeEvaluatedModule of - Left err -> return $ Result outs $ Left err - Right evaluatedModule -> do + Failure err -> return $ Result outs $ Failure err + Success evaluatedModule -> do TopStateEx curState <- getTopStateEx setTopStateEx $ extendTopStateD curState evaluatedModule - return $ Result outs $ Right () + return $ Result outs $ Success () Command cmd (v, m) -> liftPassesM_ (requiresBench block) case cmd of EvalExpr fmt -> do val <- evalUModuleVal v m @@ -241,13 +242,13 @@ filterLogs block (Result outs err) = let summarizeModuleResults :: [Result] -> Result summarizeModuleResults results = - case [err | Result _ (Left err) <- results] of - [] -> Result allOuts $ Right () + case [err | Result _ (Failure err) <- results] of + [] -> Result allOuts $ Success () errs -> Result allOuts $ throw ModuleImportErr $ foldMap pprint errs where allOuts = foldMap resultOutputs results emptyResult :: Result -emptyResult = Result [] (Right ()) +emptyResult = Result [] (Success ()) evalFile :: MonadInterblock m => FilePath -> m [(SourceBlock, Result)] evalFile fname = evalSourceText =<< (liftIO $ readFile fname) @@ -297,7 +298,7 @@ evalUModuleVal v m = do AtomBinderInfo _ (LetBound _ (Atom atom)) -> return atom _ -> throw TypeErr $ "Not an atom name: " ++ pprint v -lookupSourceName :: MonadErr m => TopStateEx -> SourceName -> m AnyBinderInfo +lookupSourceName :: Fallible m => TopStateEx -> SourceName -> m AnyBinderInfo lookupSourceName (TopStateEx topState) v = let D.TopState bindings _ (SourceMap sourceMap) = topStateD topState in case M.lookup v sourceMap of @@ -317,13 +318,13 @@ evalUModule sourceModule = do logPass Parse sourceModule renamed <- renameSourceNames bindings sourceMap sourceModule logPass RenamePass renamed - typed <- liftEither $ inferModule bindings renamed + typed <- liftExcept $ inferModule bindings renamed -- This is a (hopefully) no-op pass. It's here as a sanity check to test the -- safer names system while we're staging it in. checkPass TypePass typed typed' <- roundtripSaferNamesPass typed checkPass TypePass typed' - synthed <- liftEither $ synthModule bindings synthCandidates typed' + synthed <- liftExcept $ synthModule bindings synthCandidates typed' -- TODO: check that the type of module exports doesn't change from here on checkPass SynthPass synthed let defunctionalized = simplifyModule bindings synthed @@ -428,24 +429,15 @@ checkPass name x = do (S.Distinct, topState) <- getTopState let scope = topBindings $ topStateD topState logPass name x - liftEither $ checkValid scope x + liftExcept $ checkValid scope x logTop $ MiscLog $ pprint name ++ " checks passed" logPass :: (MonadPasses m, Pretty a) => PassName -> a -> m n () logPass passName x = logTop $ PassInfo passName $ pprint x addResultCtx :: SourceBlock -> Result -> Result -addResultCtx block (Result outs maybeErr) = case maybeErr of - Left err -> Result outs $ Left $ addCtx block err - Right () -> Result outs $ Right () - -addCtx :: SourceBlock -> Err -> Err -addCtx block err@(Err e src s) = case src of - Nothing -> err - Just (start, stop) -> - Err e Nothing $ s ++ "\n\n" ++ ctx - where n = sbOffset block - ctx = highlightRegion (start - n, stop - n) (sbText block) +addResultCtx block (Result outs errs) = + Result outs (addSrcTextContext (sbOffset block) (sbText block) errs) logTop :: MonadPasses m => Output -> m n () logTop x = logIO [x] @@ -504,16 +496,6 @@ instance MonadPasses PassesM where requireBenchmark = PassesM $ asks \(bench, _, _) -> bench getTopState = PassesM $ asks \(_ , _, s) -> s -instance MonadError Err (PassesM n) where - throwError err = liftEitherIO $ throwError err - catchError (PassesM m) f = PassesM do - env <- ask - l <- runPassesM' getLogger - result <- runExceptT $ catchIOExcept $ runLoggerT l $ runReaderT m env - case result of - Left e -> runPassesM' $ f e - Right x -> return x - instance MonadLogger [Output] (PassesM n) where getLogger = PassesM $ lift $ getLogger diff --git a/src/lib/Type.hs b/src/lib/Type.hs index de3c5c106..d0f9cc256 100644 --- a/src/lib/Type.hs +++ b/src/lib/Type.hs @@ -20,7 +20,6 @@ module Type ( import Prelude hiding (pi) import Control.Monad -import Control.Monad.Except hiding (Except) import Control.Monad.Reader import Data.Foldable (toList, traverse_) import Data.Functor @@ -53,8 +52,8 @@ getType :: (HasCallStack, HasType a) => a -> Type getType x = ignoreExcept $ ctx $ runTypeCheck SkipChecks $ typeCheck x where ctx = addContext $ "Querying:\n" ++ pprint x -tryGetType :: (MonadErr m, HasCallStack, HasType a) => a -> m Type -tryGetType x = liftEither $ ctx $ runTypeCheck SkipChecks $ typeCheck x +tryGetType :: (Fallible m, HasCallStack, HasType a) => a -> m Type +tryGetType x = liftExcept $ ctx $ runTypeCheck SkipChecks $ typeCheck x where ctx = addContext $ "Querying:\n" ++ pprint x checkType :: HasType a => TypeEnv -> EffectRow -> a -> Except () @@ -71,8 +70,8 @@ traceCheckM x = traceCheck x (return ()) traceCheck :: (HasCallStack, HasVars a, HasType a) => a -> b -> b traceCheck x y = case checkType (freeVars x) Pure x of - Right () -> y - Left e -> error $ "Check failed: " ++ pprint x ++ "\n" ++ pprint e + Success () -> y + Failure e -> error $ "Check failed: " ++ pprint x ++ "\n" ++ pprint e -- === Module interfaces === @@ -488,10 +487,8 @@ goneBy ir = do curIR <- ask when (curIR >= ir) $ throw IRVariantErr $ "shouldn't appear after " ++ show ir -addExpr :: (Pretty e, MonadError Err m) => e -> m a -> m a -addExpr x m = modifyErr m \e -> case e of - Err IRVariantErr ctx s -> Err CompilerErr ctx (s ++ ": " ++ pprint x) - _ -> e +addExpr :: (Pretty e, Fallible m) => e -> m a -> m a +addExpr x m = addContext (pprint x) m -- === effects === @@ -514,7 +511,7 @@ declareEffs :: EffectRow -> TypeM () declareEffs effs = checkWithEnv \(_, allowedEffects) -> checkExtends allowedEffects effs -checkExtends :: MonadError Err m => EffectRow -> EffectRow -> m () +checkExtends :: Fallible m => EffectRow -> EffectRow -> m () checkExtends allowed (EffectRow effs effTail) = do let (EffectRow allowedEffs allowedEffTail) = allowed case effTail of @@ -647,7 +644,7 @@ typeCheckRef x = do TC (RefType _ a) <- typeCheck x return a -checkIntBaseType :: MonadError Err m => Bool -> Type -> m () +checkIntBaseType :: Fallible m => Bool -> Type -> m () checkIntBaseType allowVector t = case t of BaseTy (Scalar sbt) -> checkSBT sbt BaseTy (Vector sbt) | allowVector -> checkSBT sbt @@ -663,7 +660,7 @@ checkIntBaseType allowVector t = case t of notInt = throw TypeErr $ "Expected a fixed-width " ++ (if allowVector then "" else "scalar ") ++ "integer type, but found: " ++ pprint t -checkFloatBaseType :: MonadError Err m => Bool -> Type -> m () +checkFloatBaseType :: Fallible m => Bool -> Type -> m () checkFloatBaseType allowVector t = case t of BaseTy (Scalar sbt) -> checkSBT sbt BaseTy (Vector sbt) | allowVector -> checkSBT sbt @@ -967,14 +964,14 @@ litType v = case v of data ArgumentType = SomeFloatArg | SomeIntArg | SomeUIntArg data ReturnType = SameReturn | Word8Return -checkOpArgType :: MonadError Err m => ArgumentType -> Type -> m () +checkOpArgType :: Fallible m => ArgumentType -> Type -> m () checkOpArgType argTy x = case argTy of SomeIntArg -> checkIntBaseType True x SomeUIntArg -> assertEq x Word8Ty "" SomeFloatArg -> checkFloatBaseType True x -checkBinOp :: MonadError Err m => BinOp -> Type -> Type -> m Type +checkBinOp :: Fallible m => BinOp -> Type -> Type -> m Type checkBinOp op x y = do checkOpArgType argTy x assertEq x y "" @@ -998,7 +995,7 @@ checkBinOp op x y = do ia = SomeIntArg; fa = SomeFloatArg br = Word8Return; sr = SameReturn -checkUnOp :: MonadError Err m => UnOp -> Type -> m Type +checkUnOp :: Fallible m => UnOp -> Type -> m Type checkUnOp op x = do checkOpArgType argTy x return $ case retTy of @@ -1030,7 +1027,7 @@ indexSetConcreteSize ty = case ty of FixedIntRange low high -> Just $ fromIntegral $ high - low _ -> Nothing -checkDataLike :: MonadError Err m => String -> Type -> m () +checkDataLike :: Fallible m => String -> Type -> m () checkDataLike msg ty = case ty of Var _ -> error "Not implemented" TabTy _ b -> recur b @@ -1049,18 +1046,18 @@ checkDataLike msg ty = case ty of _ -> throw TypeErr $ pprint ty ++ msg where recur x = checkDataLike msg x -checkDataLikeDataCon :: MonadError Err m => DataConDef -> m () +checkDataLikeDataCon :: Fallible m => DataConDef -> m () checkDataLikeDataCon (DataConDef _ bs) = mapM_ (checkDataLike "data con binder" . binderAnn) bs -checkData :: MonadError Err m => Type -> m () +checkData :: Fallible m => Type -> m () checkData = checkDataLike " is not serializable" --TODO: Make this work even if the type has type variables! isData :: Type -> Bool isData ty = case checkData ty of - Left _ -> False - Right _ -> True + Failure _ -> False + Success _ -> True projectLength :: Type -> Int projectLength ty = case ty of diff --git a/src/lib/Util.hs b/src/lib/Util.hs index 163d4a371..0bff94114 100644 --- a/src/lib/Util.hs +++ b/src/lib/Util.hs @@ -11,7 +11,7 @@ module Util (IsBool (..), group, ungroup, pad, padLeft, delIdx, replaceIdx, insertIdx, mvIdx, mapFst, mapSnd, splitOn, scan, scanM, composeN, mapMaybe, uncons, repeated, transitiveClosure, showErr, listDiff, splitMap, enumerate, restructure, - onSnd, onFst, highlightRegion, findReplace, swapAt, uncurry3, + onSnd, onFst, findReplace, swapAt, uncurry3, measureSeconds, bindM2, foldMapM, lookupWithIdx, (...), zipWithT, for, Zippable (..), zipWithZ_, zipErr, forMZipped, forMZipped_, @@ -145,38 +145,6 @@ restructure xs structure = evalState (traverse procLeaf structure) xs put rest return x -highlightRegion :: (Int, Int) -> String -> String -highlightRegion pos@(low, high) s - | low > high || high > length s = error $ "Bad region: \n" - ++ show pos ++ "\n" ++ s - | otherwise = - -- TODO: flag to control line numbers - -- (disabling for now because it makes quine tests tricky) - -- "Line " ++ show (1 + lineNum) ++ "\n" - - allLines !! lineNum ++ "\n" - ++ take start (repeat ' ') ++ take (stop - start) (repeat '^') ++ "\n" - where - allLines = lines s - (lineNum, start, stop) = getPosTriple pos allLines - -getPosTriple :: (Int, Int) -> [String] -> (Int, Int, Int) -getPosTriple (start, stop) lines_ = (lineNum, start - offset, stop') - where - lineLengths = map ((+1) . length) lines_ - lineOffsets = cumsum lineLengths - lineNum = maxLT lineOffsets start - offset = lineOffsets !! lineNum - stop' = min (stop - offset) (lineLengths !! lineNum) - -cumsum :: [Int] -> [Int] -cumsum xs = scanl (+) 0 xs - -maxLT :: Ord a => [a] -> a -> Int -maxLT [] _ = 0 -maxLT (x:xs) n = if n < x then -1 - else 1 + maxLT xs n - -- TODO: find a more efficient implementation findReplace :: Eq a => [a] -> [a] -> [a] -> [a] findReplace _ _ [] = []