Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions unison-cli/src/Unison/CommandLine.hs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ module Unison.CommandLine
parseInput,
prompt,
reportParseFailure,

-- * Shared Helpers
defaultLoadSourceFile,
defaultWriteSourceFile,
)
where

Expand All @@ -22,9 +26,12 @@ import Data.Map qualified as Map
import Data.Text qualified as Text
import Data.Vector qualified as Vector
import System.FilePath (takeFileName)
import System.IO.Error (isDoesNotExistError)
import Text.Numeral (defaultInflection)
import Text.Numeral.Language.ENG qualified as Numeral
import Text.Regex.TDFA ((=~))
import Unison.Cli.Monad (LoadSourceResult)
import Unison.Cli.Monad qualified as Cli
import Unison.Codebase (Codebase)
import Unison.Codebase.Branch (Branch0)
import Unison.Codebase.Branch qualified as Branch
Expand All @@ -43,6 +50,8 @@ import Unison.Prelude
import Unison.PrettyTerminal qualified as PrettyTerm
import Unison.Symbol (Symbol)
import Unison.Util.Pretty qualified as P
import UnliftIO (catch)
import UnliftIO.Directory qualified as Directory
import Prelude hiding (readFile, writeFile)

allow :: FilePath -> Bool
Expand All @@ -51,6 +60,30 @@ allow p =
not (".#" `isPrefixOf` takeFileName p)
&& (isSuffixOf ".u" p || isSuffixOf ".uu" p)

defaultWriteSourceFile :: Text -> Text -> Bool -> IO ()
defaultWriteSourceFile fp contents addFold = do
path <- Directory.canonicalizePath (Text.unpack fp)
prependUtf8
path
if addFold
then contents <> "\n\n---- Anything below this line is ignored by Unison.\n\n"
else contents <> "\n\n"

defaultLoadSourceFile :: Text -> IO LoadSourceResult
defaultLoadSourceFile fname =
if allow $ Text.unpack fname
then
let handle :: IOException -> IO LoadSourceResult
handle e =
case e of
_ | isDoesNotExistError e -> return Cli.InvalidSourceNameError
_ -> return Cli.LoadError
go = do
contents <- readUtf8 $ Text.unpack fname
return $ Cli.LoadSuccess contents
in catch go handle
else return Cli.InvalidSourceNameError

data ExpansionFailure
= TooManyArguments (NonEmpty InputPattern.Argument)
| UnexpectedStructuredArgument StructuredArgument
Expand Down
31 changes: 3 additions & 28 deletions unison-cli/src/Unison/CommandLine/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ module Unison.CommandLine.Main
where

import Compat (withInterruptHandler)
import Control.Exception (catch, displayException, mask)
import Control.Exception (displayException, mask)
import Control.Lens ((?~))
import Control.Lens.Lens
import Crypto.Random qualified as Random
Expand All @@ -21,7 +21,6 @@ import System.Console.Haskeline qualified as Line
import System.Console.Haskeline.History qualified as Line
import System.FSNotify qualified as FSNotify
import System.IO (hGetEcho, hPutStrLn, hSetEcho, stderr, stdin)
import System.IO.Error (isDoesNotExistError)
import U.Codebase.Sqlite.Queries qualified as Queries
import Unison.Auth.CredentialManager qualified as AuthN
import Unison.Auth.HTTPClient (AuthenticatedHttpClient)
Expand Down Expand Up @@ -57,7 +56,6 @@ import Unison.Symbol (Symbol)
import Unison.Syntax.Parser qualified as Parser
import Unison.Util.Pretty qualified as P
import UnliftIO qualified
import UnliftIO.Directory qualified as Directory
import UnliftIO.STM

getUserInput ::
Expand Down Expand Up @@ -235,20 +233,6 @@ main dir welcome ppIds initialInputs runtime sbRuntime codebase serverBaseUrl uc
pp
getProjectRoot
(loopState ^. #numberedArgs)
let loadSourceFile :: Text -> IO Cli.LoadSourceResult
loadSourceFile fname =
if allow $ Text.unpack fname
then
let handle :: IOException -> IO Cli.LoadSourceResult
handle e =
case e of
_ | isDoesNotExistError e -> return Cli.InvalidSourceNameError
_ -> return Cli.LoadError
go = do
contents <- readUtf8 $ Text.unpack fname
return $ Cli.LoadSuccess contents
in catch go handle
else return Cli.InvalidSourceNameError
let notify :: Output -> IO ()
notify =
notifyUser (pure dir) fetchIssueFromGitHub
Expand Down Expand Up @@ -282,23 +266,14 @@ main dir welcome ppIds initialInputs runtime sbRuntime codebase serverBaseUrl uc
]
action

let writeSource :: Text -> Text -> Bool -> IO ()
writeSource fp contents addFold = do
path <- Directory.canonicalizePath (Text.unpack fp)
prependUtf8
path
if addFold
then contents <> "\n\n---- Anything below this line is ignored by Unison.\n\n"
else contents <> "\n\n"

let env =
Cli.Env
{ authHTTPClient,
codebase,
credentialManager,
loadSource = loadSourceFile,
loadSource = defaultLoadSourceFile,
lspCheckForChanges,
writeSource,
writeSource = defaultWriteSourceFile,
generateUniqueName = Parser.uniqueBase32Namegen <$> Random.getSystemDRG,
notify,
notifyNumbered = \o ->
Expand Down
77 changes: 55 additions & 22 deletions unison-cli/src/Unison/MCP/Cli.hs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module Unison.MCP.Cli
( handleInputMCP,
ppForProjectContext,
cliToMCP,
virtualSourceName,
)
where

Expand All @@ -19,8 +20,10 @@ import Unison.Cli.Monad qualified as Cli
import Unison.Codebase qualified as Codebase
import Unison.Codebase.Editor.HandleInput qualified as HandleInput
import Unison.Codebase.Editor.Input (Event, Input)
import Unison.Codebase.Editor.Output qualified as Output
import Unison.Codebase.Path qualified as Path
import Unison.Codebase.ProjectPath qualified as PP
import Unison.CommandLine (defaultLoadSourceFile, defaultWriteSourceFile)
import Unison.CommandLine.OutputMessages qualified as Output
import Unison.MCP.Types
import Unison.MCP.Types qualified as MCP
Expand All @@ -31,24 +34,29 @@ import Unison.Util.Pretty qualified as Pretty
import UnliftIO.STM
import Prelude hiding (readFile, writeFile)

virtualSourceName :: Text
virtualSourceName = "<mcp-virtual-source>"

data CliOutput = CliOutput
{ sourceCodeUpdates :: [Text],
outputMessages :: [Text]
outputMessages :: [Text],
errorMessages :: [Text]
}
deriving (Eq, Show)

instance Semigroup CliOutput where
CliOutput src1 out1 <> CliOutput src2 out2 =
CliOutput (src1 <> src2) (out1 <> out2)
CliOutput src1 out1 errs1 <> CliOutput src2 out2 errs2 =
CliOutput (src1 <> src2) (out1 <> out2) (errs1 <> errs2)

instance Monoid CliOutput where
mempty = CliOutput [] []
mempty = CliOutput [] [] []

instance ToJSON CliOutput where
toJSON (CliOutput sourceCodeUpdates outputMessages) =
toJSON (CliOutput sourceCodeUpdates outputMessages errorMessages) =
object
[ "sourceCodeUpdates" .= sourceCodeUpdates,
"outputMessages" .= outputMessages
"outputMessages" .= outputMessages,
"errorMessages" .= errorMessages
]

ppForProjectContext :: ProjectContext -> ExceptT Text Transaction PP.ProjectPath
Expand All @@ -64,14 +72,24 @@ ppForProjectContext ProjectContext {projectName, branchName} = do

handleInputMCP :: ProjectContext -> [Either Event Input] -> ExceptT Text MCP CliOutput
handleInputMCP projectContext input = do
case input of
(inp : rest) -> do
(_, cliOutput) <- cliToMCP projectContext (HandleInput.loop inp)
(cliOutput <>) <$> handleInputMCP projectContext rest
[] -> pure mempty
hasErroredVar <- newTVarIO False
let onErr _errMsg = atomically $ writeTVar hasErroredVar True
result <- cliToMCP projectContext onErr do
Cli.labelE \fail' -> do
for_ input \inp -> do
HandleInput.loop inp
readTVarIO hasErroredVar >>= \case
False -> pure ()
True -> fail' "An error occurred during input handling."
case result of
(Nothing, cliOut) -> pure cliOut
(Just (Left err), cliOutput) ->
pure $ cliOutput <> mempty {errorMessages = [err]}
(Just (Right ()), cliOutput) ->
pure cliOutput

cliToMCP :: ProjectContext -> Cli.Cli a -> ExceptT Text MCP (Maybe a, CliOutput)
cliToMCP projCtx cli = do
cliToMCP :: ProjectContext -> (Text -> IO ()) -> Cli.Cli a -> ExceptT Text MCP (Maybe a, CliOutput)
cliToMCP projCtx onError cli = do
MCP.Env {ucmVersion, codebase, runtime, workDir} <- ask
initialPP <- ExceptT . liftIO $ Codebase.runTransactionExceptT codebase $ do
ppForProjectContext projCtx
Expand All @@ -80,22 +98,31 @@ cliToMCP projCtx cli = do
tokenProvider = AuthN.newTokenProvider credMan
authenticatedHTTPClient <- AuthN.newAuthenticatedHTTPClient tokenProvider ucmVersion
outputVar <- newTVarIO Seq.empty
errorsVar <- newTVarIO Seq.empty
sourceCodeUpdatesVar <- newTVarIO Seq.empty
let notify output = do
pretty <- Output.notifyUser workDir Output.fetchIssueFromGitHub output
atomically $ modifyTVar outputVar (<> Seq.singleton pretty)
if (Output.isFailure output)
then do
atomically $ modifyTVar errorsVar (<> Seq.singleton pretty)
liftIO $ onError (Pretty.toPlain 0 pretty)
else do
atomically $ modifyTVar outputVar (<> Seq.singleton pretty)
let notifyNumbered output = do
let (pretty, nargs) = Output.notifyNumbered output
atomically $ modifyTVar outputVar (<> Seq.singleton pretty)
pure nargs

let loadSource = error "loadSource is not implemented for the MCP server."
let writeSource _sourceName content replace = do
if replace
then do
atomically $ writeTVar sourceCodeUpdatesVar (Seq.singleton content)
let writeSource sourceName content replace = do
if sourceName == virtualSourceName
then
if replace
then do
atomically $ writeTVar sourceCodeUpdatesVar (Seq.singleton content)
else do
atomically $ modifyTVar sourceCodeUpdatesVar (<> Seq.singleton content)
else do
atomically $ modifyTVar sourceCodeUpdatesVar (<> Seq.singleton content)
defaultWriteSourceFile sourceName content replace

seedRef <- liftIO $ newIORef (0 :: Int)
let cliEnv =
Expand All @@ -106,7 +133,7 @@ cliToMCP projCtx cli = do
generateUniqueName = do
i <- atomicModifyIORef' seedRef \i -> let !i' = i + 1 in (i', i)
pure (Parser.uniqueBase32Namegen (Random.drgNewSeed (Random.seedFromInteger (fromIntegral i)))),
loadSource,
loadSource = defaultLoadSourceFile,
lspCheckForChanges = \_ -> pure (),
writeSource,
notify,
Expand All @@ -124,15 +151,21 @@ cliToMCP projCtx cli = do
-- flush the output buffer since it should now be filled.
cliOut <- atomically $ do
msgs <- readTVar outputVar
errs <- readTVar errorsVar
sourceCodeUpdates <- toList <$> readTVar sourceCodeUpdatesVar
let outputMessages =
msgs
& fmap (Pretty.toPlain 0)
& toList
let errorMessages =
errs
& fmap (Pretty.toPlain 0)
& toList
pure $
( CliOutput
{ sourceCodeUpdates,
outputMessages
outputMessages,
errorMessages
}
)
case cliResult of
Expand Down
Loading