diff --git a/README.md b/README.md index ac48727..cfcd344 100644 --- a/README.md +++ b/README.md @@ -203,6 +203,46 @@ ssage --help # https://github.com/AnswerDotAI/shell_sage/issues ``` +### RAG Configuration + +ShellSage supports Retrieval-Augmented Generation (RAG) to enhance +responses with relevant man pages: + +``` ini +[DEFAULT] +# RAG settings +use_retrieval = false # Enable/disable RAG +retrieve_limit = 5 # Number of man pages to retrieve +``` + +To use RAG functionality: + +1. Install RAG dependencies: + +``` sh +pip install 'shell_sage[rag]' +``` + +2. Build the man page vector database: + +``` sh +ssage_index +``` + +3. Enable retrieval in queries: + +``` sh +# Use RAG for a single query +ssage --use-retrieval "how do I use rsync?" + +# Or enable it permanently by setting this in your config file: +use_retrieval = true +``` + +The RAG system enhances responses by retrieving relevant man pages and +including them in the context, leading to more accurate command-line +assistance. + ## Contributing ShellSage is built using [nbdev](https://nbdev.fast.ai/). 
For detailed diff --git a/nbs/00_core.ipynb b/nbs/00_core.ipynb index 2e51568..562f85e 100644 --- a/nbs/00_core.ipynb +++ b/nbs/00_core.ipynb @@ -44,6 +44,8 @@ "from rich.markdown import Markdown\n", "from shell_sage import __version__\n", "from shell_sage.config import *\n", + "try: from shell_sage.rag import *\n", + "except: pass\n", "from subprocess import check_output as co\n", "\n", "import os,re,subprocess,sys\n", @@ -220,78 +222,49 @@ "metadata": {}, "outputs": [ { - "name": "stderr", + "name": "stdout", "output_type": "stream", "text": [ - "bash: no job control in this shell\n" + "alias ans='ssh answer'\n", + "alias b='ssage'\n", + "alias bc='ssage --c'\n", + "alias breaking='gh issue create -l breaking -b '\\'''\\'' -t'\n", + "alias bs='ssage --s'\n", + "alias bug='gh issue create -l bug -b '\\'''\\'' -t'\n", + "alias bump='nbdev_bump_version && commit bump'\n", + "alias enhancement='gh issue create -l enhancement -b '\\'''\\'' -t'\n", + "alias gaa='git add -A'\n", + "alias gc='git checkout'\n", + "alias gd='git diff'\n", + "alias git1st='git log --reverse --pretty=format:\"%h %an %ad : %s\" --date=local | head -1'\n", + "alias gitlog='git log -10 --pretty=format:\"%h %an %ad : %s\" --date=local'\n", + "alias gitssh='perl -pi -e '\\''s#https://github\\.com/#git\\@github.com:# if /\\[remote \"origin/../fetch =/'\\'' .git/config'\n", + "alias gp='git pull'\n", + "alias gpu='git push'\n", + "alias gs='git status'\n", + "alias issue='gh issue create'\n", + "alias issues='gh issue list'\n", + "alias jnb='jupyter nbclassic --NotebookApp.token=\"\" --NotebookApp.password=\"\"'\n", + "alias pr='git diff main | bc please generate a PR title and body for these changes. 
Make sure to use the github cli'\n", + "alias prep='nbdev_export && nbdev_clean && nbdev_trust'\n", + "alias qc='git diff --cached | bc please generate a concise git commit for these changes'\n", + "alias recent='ls -lth | head -n 20'\n", + "alias tb1='ssh tb1'\n", + "alias tb2='ssh tb2'\n", + "alias topypi='rm -rf dist/* && python -m build && twine upload dist/*'\n", + "alias tunnel='cloudflared tunnel --url http://localhost:5001'\n", + "alias upi='uv pip install'\n", + "alias upie='uv pip install --config-settings editable_mode=compat -e'\n", + "alias vim='nvim'\n" ] }, { - "data": { - "text/html": [ - "
alias ans='ssh answer'\n",
-       "alias b='ssage'\n",
-       "alias breaking='gh issue create -l breaking -b '\\'''\\'' -t'\n",
-       "alias bs='ssage --s'\n",
-       "alias bug='gh issue create -l bug -b '\\'''\\'' -t'\n",
-       "alias bump='nbdev_bump_version && commit bump'\n",
-       "alias enhancement='gh issue create -l enhancement -b '\\'''\\'' -t'\n",
-       "alias gaa='git add -A'\n",
-       "alias gc='git checkout'\n",
-       "alias gd='git diff'\n",
-       "alias git1st='git log --reverse --pretty=format:\"%h %an %ad : %s\" --date=local | head -1'\n",
-       "alias gitlog='git log -10 --pretty=format:\"%h %an %ad : %s\" --date=local'\n",
-       "alias gitssh='perl -pi -e '\\''s#https://github\\.com/#git\\@github.com:# if /[remote \"origin/../fetch =/'\\'' \n",
-       ".git/config'\n",
-       "alias gp='git pull'\n",
-       "alias gpu='git push'\n",
-       "alias gs='git status'\n",
-       "alias issue='gh issue create'\n",
-       "alias issues='gh issue list'\n",
-       "alias jnb='jupyter nbclassic'\n",
-       "alias prep='nbdev_export && nbdev_clean && nbdev_trust'\n",
-       "alias recent='ls -lth | head -n 20'\n",
-       "alias tb1='ssh tb1'\n",
-       "alias tb2='ssh tb2'\n",
-       "alias topypi='rm -rf dist/* && python -m build && twine upload dist/*'\n",
-       "alias tunnel='cloudflared tunnel --url http://localhost:5001'\n",
-       "alias upi='uv pip install'\n",
-       "alias upie='uv pip install --config-settings editable_mode=compat -e'\n",
-       "
\n" - ], - "text/plain": [ - "alias \u001b[33mans\u001b[0m=\u001b[32m'ssh answer'\u001b[0m\n", - "alias \u001b[33mb\u001b[0m=\u001b[32m'ssage'\u001b[0m\n", - "alias \u001b[33mbreaking\u001b[0m=\u001b[32m'gh issue create -l breaking -b '\u001b[0m\\'\u001b[32m''\u001b[0m\\'\u001b[32m' -t'\u001b[0m\n", - "alias \u001b[33mbs\u001b[0m=\u001b[32m'ssage --s'\u001b[0m\n", - "alias \u001b[33mbug\u001b[0m=\u001b[32m'gh issue create -l bug -b '\u001b[0m\\'\u001b[32m''\u001b[0m\\'\u001b[32m' -t'\u001b[0m\n", - "alias \u001b[33mbump\u001b[0m=\u001b[32m'nbdev_bump_version && commit bump'\u001b[0m\n", - "alias \u001b[33menhancement\u001b[0m=\u001b[32m'gh issue create -l enhancement -b '\u001b[0m\\'\u001b[32m''\u001b[0m\\'\u001b[32m' -t'\u001b[0m\n", - "alias \u001b[33mgaa\u001b[0m=\u001b[32m'git add -A'\u001b[0m\n", - "alias \u001b[33mgc\u001b[0m=\u001b[32m'git checkout'\u001b[0m\n", - "alias \u001b[33mgd\u001b[0m=\u001b[32m'git diff'\u001b[0m\n", - "alias \u001b[33mgit1st\u001b[0m=\u001b[32m'git log --reverse --\u001b[0m\u001b[32mpretty\u001b[0m\u001b[32m=\u001b[0m\u001b[32mformat\u001b[0m\u001b[32m:\"%h %an %ad : %s\" --\u001b[0m\u001b[32mdate\u001b[0m\u001b[32m=\u001b[0m\u001b[32mlocal\u001b[0m\u001b[32m | head -1'\u001b[0m\n", - "alias \u001b[33mgitlog\u001b[0m=\u001b[32m'git log -10 --\u001b[0m\u001b[32mpretty\u001b[0m\u001b[32m=\u001b[0m\u001b[32mformat\u001b[0m\u001b[32m:\"%h %an %ad : %s\" --\u001b[0m\u001b[32mdate\u001b[0m\u001b[32m=\u001b[0m\u001b[32mlocal\u001b[0m\u001b[32m'\u001b[0m\n", - "alias \u001b[33mgitssh\u001b[0m=\u001b[32m'perl -pi -e '\u001b[0m\\'\u001b[32m's#https://github\\.com/#git\\@github.com:# if /\u001b[0m\u001b[32m[\u001b[0m\u001b[32mremote \"origin/../fetch =/'\u001b[0m\\'\u001b[32m' \u001b[0m\n", - "\u001b[32m.git/config'\u001b[0m\n", - "alias \u001b[33mgp\u001b[0m=\u001b[32m'git pull'\u001b[0m\n", - "alias \u001b[33mgpu\u001b[0m=\u001b[32m'git push'\u001b[0m\n", - "alias \u001b[33mgs\u001b[0m=\u001b[32m'git status'\u001b[0m\n", - "alias 
\u001b[33missue\u001b[0m=\u001b[32m'gh issue create'\u001b[0m\n", - "alias \u001b[33missues\u001b[0m=\u001b[32m'gh issue list'\u001b[0m\n", - "alias \u001b[33mjnb\u001b[0m=\u001b[32m'jupyter nbclassic'\u001b[0m\n", - "alias \u001b[33mprep\u001b[0m=\u001b[32m'nbdev_export && nbdev_clean && nbdev_trust'\u001b[0m\n", - "alias \u001b[33mrecent\u001b[0m=\u001b[32m'ls -lth | head -n 20'\u001b[0m\n", - "alias \u001b[33mtb1\u001b[0m=\u001b[32m'ssh tb1'\u001b[0m\n", - "alias \u001b[33mtb2\u001b[0m=\u001b[32m'ssh tb2'\u001b[0m\n", - "alias \u001b[33mtopypi\u001b[0m=\u001b[32m'rm -rf dist/* && python -m build && twine upload dist/*'\u001b[0m\n", - "alias \u001b[33mtunnel\u001b[0m=\u001b[32m'cloudflared tunnel --url http://localhost:5001'\u001b[0m\n", - "alias \u001b[33mupi\u001b[0m=\u001b[32m'uv pip install'\u001b[0m\n", - "alias \u001b[33mupie\u001b[0m=\u001b[32m'uv pip install --config-settings \u001b[0m\u001b[32meditable_mode\u001b[0m\u001b[32m=\u001b[0m\u001b[32mcompat\u001b[0m\u001b[32m -e'\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" + "name": "stderr", + "output_type": "stream", + "text": [ + "bash: cannot set terminal process group (99809): Inappropriate ioctl for device\n", + "bash: no job control in this shell\n" + ] } ], "source": [ @@ -324,92 +297,54 @@ "metadata": {}, "outputs": [ { - "name": "stderr", + "name": "stdout", "output_type": "stream", "text": [ - "bash: no job control in this shell\n" + "\n", + "Darwin Nathans-MacBook-Air.local 24.1.0 Darwin Kernel Version 24.1.0: Thu Oct 10 21:02:26 PDT 2024; root:xnu-11215.41.3~2/RELEASE_ARM64_T8122 arm64\n", + "/bin/bash\n", + "\n", + "alias ans='ssh answer'\n", + "alias b='ssage'\n", + "alias bc='ssage --c'\n", + "alias breaking='gh issue create -l breaking -b '\\'''\\'' -t'\n", + "alias bs='ssage --s'\n", + "alias bug='gh issue create -l bug -b '\\'''\\'' -t'\n", + "alias bump='nbdev_bump_version && commit bump'\n", + "alias enhancement='gh issue create -l enhancement -b '\\'''\\'' 
-t'\n", + "alias gaa='git add -A'\n", + "alias gc='git checkout'\n", + "alias gd='git diff'\n", + "alias git1st='git log --reverse --pretty=format:\"%h %an %ad : %s\" --date=local | head -1'\n", + "alias gitlog='git log -10 --pretty=format:\"%h %an %ad : %s\" --date=local'\n", + "alias gitssh='perl -pi -e '\\''s#https://github\\.com/#git\\@github.com:# if /\\[remote \"origin/../fetch =/'\\'' .git/config'\n", + "alias gp='git pull'\n", + "alias gpu='git push'\n", + "alias gs='git status'\n", + "alias issue='gh issue create'\n", + "alias issues='gh issue list'\n", + "alias jnb='jupyter nbclassic --NotebookApp.token=\"\" --NotebookApp.password=\"\"'\n", + "alias pr='git diff main | bc please generate a PR title and body for these changes. Make sure to use the github cli'\n", + "alias prep='nbdev_export && nbdev_clean && nbdev_trust'\n", + "alias qc='git diff --cached | bc please generate a concise git commit for these changes'\n", + "alias recent='ls -lth | head -n 20'\n", + "alias tb1='ssh tb1'\n", + "alias tb2='ssh tb2'\n", + "alias topypi='rm -rf dist/* && python -m build && twine upload dist/*'\n", + "alias tunnel='cloudflared tunnel --url http://localhost:5001'\n", + "alias upi='uv pip install'\n", + "alias upie='uv pip install --config-settings editable_mode=compat -e'\n", + "alias vim='nvim'\n", + "\n", + "\n" ] }, { - "data": { - "text/html": [ - "
<system_info>\n",
-       "<system>Darwin Nathans-MacBook-Air.local 24.1.0 Darwin Kernel Version 24.1.0: Thu Oct 10 21:02:26 PDT 2024; \n",
-       "root:xnu-11215.41.3~2/RELEASE_ARM64_T8122 arm64</system>\n",
-       "<shell>/bin/bash</shell>\n",
-       "<aliases>\n",
-       "alias ans='ssh answer'\n",
-       "alias b='ssage'\n",
-       "alias breaking='gh issue create -l breaking -b '\\'''\\'' -t'\n",
-       "alias bs='ssage --s'\n",
-       "alias bug='gh issue create -l bug -b '\\'''\\'' -t'\n",
-       "alias bump='nbdev_bump_version && commit bump'\n",
-       "alias enhancement='gh issue create -l enhancement -b '\\'''\\'' -t'\n",
-       "alias gaa='git add -A'\n",
-       "alias gc='git checkout'\n",
-       "alias gd='git diff'\n",
-       "alias git1st='git log --reverse --pretty=format:\"%h %an %ad : %s\" --date=local | head -1'\n",
-       "alias gitlog='git log -10 --pretty=format:\"%h %an %ad : %s\" --date=local'\n",
-       "alias gitssh='perl -pi -e '\\''s#https://github\\.com/#git\\@github.com:# if /[remote \"origin/../fetch =/'\\'' \n",
-       ".git/config'\n",
-       "alias gp='git pull'\n",
-       "alias gpu='git push'\n",
-       "alias gs='git status'\n",
-       "alias issue='gh issue create'\n",
-       "alias issues='gh issue list'\n",
-       "alias jnb='jupyter nbclassic'\n",
-       "alias prep='nbdev_export && nbdev_clean && nbdev_trust'\n",
-       "alias recent='ls -lth | head -n 20'\n",
-       "alias tb1='ssh tb1'\n",
-       "alias tb2='ssh tb2'\n",
-       "alias topypi='rm -rf dist/* && python -m build && twine upload dist/*'\n",
-       "alias tunnel='cloudflared tunnel --url http://localhost:5001'\n",
-       "alias upi='uv pip install'\n",
-       "alias upie='uv pip install --config-settings editable_mode=compat -e'\n",
-       "</aliases>\n",
-       "</system_info>\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1m<\u001b[0m\u001b[1;95msystem_info\u001b[0m\u001b[39m>\u001b[0m\n", - "\u001b[39mDarwin Nathans-MacBook-Air.local \u001b[0m\u001b[1;36m24.1\u001b[0m\u001b[39m.\u001b[0m\u001b[1;36m0\u001b[0m\u001b[39m Darwin Kernel Version \u001b[0m\u001b[1;36m24.1\u001b[0m\u001b[39m.\u001b[0m\u001b[1;36m0\u001b[0m\u001b[39m: Thu Oct \u001b[0m\u001b[1;36m10\u001b[0m\u001b[39m \u001b[0m\u001b[1;92m21:02:26\u001b[0m\u001b[39m PDT \u001b[0m\u001b[1;36m2024\u001b[0m\u001b[39m; \u001b[0m\n", - "\u001b[39mroot:xnu-\u001b[0m\u001b[1;36m11215.41\u001b[0m\u001b[39m.\u001b[0m\u001b[1;36m3\u001b[0m\u001b[39m~\u001b[0m\u001b[1;36m2\u001b[0m\u001b[39m/RELEASE_ARM64_T8122 arm64<\u001b[0m\u001b[35m/\u001b[0m\u001b[95msystem\u001b[0m\u001b[39m>\u001b[0m\n", - "\u001b[39m\u001b[0m\u001b[35m/bin/\u001b[0m\u001b[95mbash\u001b[0m\u001b[39m<\u001b[0m\u001b[35m/\u001b[0m\u001b[95mshell\u001b[0m\u001b[39m>\u001b[0m\n", - "\u001b[39m\u001b[0m\n", - "\u001b[39malias \u001b[0m\u001b[33mans\u001b[0m\u001b[39m=\u001b[0m\u001b[32m'ssh answer'\u001b[0m\n", - "\u001b[39malias \u001b[0m\u001b[33mb\u001b[0m\u001b[39m=\u001b[0m\u001b[32m'ssage'\u001b[0m\n", - "\u001b[39malias \u001b[0m\u001b[33mbreaking\u001b[0m\u001b[39m=\u001b[0m\u001b[32m'gh issue create -l breaking -b '\u001b[0m\u001b[39m\\'\u001b[0m\u001b[32m''\u001b[0m\u001b[39m\\'\u001b[0m\u001b[32m' -t'\u001b[0m\n", - "\u001b[39malias \u001b[0m\u001b[33mbs\u001b[0m\u001b[39m=\u001b[0m\u001b[32m'ssage --s'\u001b[0m\n", - "\u001b[39malias \u001b[0m\u001b[33mbug\u001b[0m\u001b[39m=\u001b[0m\u001b[32m'gh issue create -l bug -b '\u001b[0m\u001b[39m\\'\u001b[0m\u001b[32m''\u001b[0m\u001b[39m\\'\u001b[0m\u001b[32m' -t'\u001b[0m\n", - "\u001b[39malias \u001b[0m\u001b[33mbump\u001b[0m\u001b[39m=\u001b[0m\u001b[32m'nbdev_bump_version && commit bump'\u001b[0m\n", - "\u001b[39malias \u001b[0m\u001b[33menhancement\u001b[0m\u001b[39m=\u001b[0m\u001b[32m'gh issue create -l enhancement -b 
'\u001b[0m\u001b[39m\\'\u001b[0m\u001b[32m''\u001b[0m\u001b[39m\\'\u001b[0m\u001b[32m' -t'\u001b[0m\n", - "\u001b[39malias \u001b[0m\u001b[33mgaa\u001b[0m\u001b[39m=\u001b[0m\u001b[32m'git add -A'\u001b[0m\n", - "\u001b[39malias \u001b[0m\u001b[33mgc\u001b[0m\u001b[39m=\u001b[0m\u001b[32m'git checkout'\u001b[0m\n", - "\u001b[39malias \u001b[0m\u001b[33mgd\u001b[0m\u001b[39m=\u001b[0m\u001b[32m'git diff'\u001b[0m\n", - "\u001b[39malias \u001b[0m\u001b[33mgit1st\u001b[0m\u001b[39m=\u001b[0m\u001b[32m'git log --reverse --\u001b[0m\u001b[32mpretty\u001b[0m\u001b[32m=\u001b[0m\u001b[32mformat\u001b[0m\u001b[32m:\"%h %an %ad : %s\" --\u001b[0m\u001b[32mdate\u001b[0m\u001b[32m=\u001b[0m\u001b[32mlocal\u001b[0m\u001b[32m | head -1'\u001b[0m\n", - "\u001b[39malias \u001b[0m\u001b[33mgitlog\u001b[0m\u001b[39m=\u001b[0m\u001b[32m'git log -10 --\u001b[0m\u001b[32mpretty\u001b[0m\u001b[32m=\u001b[0m\u001b[32mformat\u001b[0m\u001b[32m:\"%h %an %ad : %s\" --\u001b[0m\u001b[32mdate\u001b[0m\u001b[32m=\u001b[0m\u001b[32mlocal\u001b[0m\u001b[32m'\u001b[0m\n", - "\u001b[39malias \u001b[0m\u001b[33mgitssh\u001b[0m\u001b[39m=\u001b[0m\u001b[32m'perl -pi -e '\u001b[0m\u001b[39m\\'\u001b[0m\u001b[32m's#https://github\\.com/#git\\@github.com:# if /\u001b[0m\u001b[32m[\u001b[0m\u001b[32mremote \"origin/../fetch =/'\u001b[0m\u001b[39m\\'\u001b[0m\u001b[32m' \u001b[0m\n", - "\u001b[32m.git/config'\u001b[0m\n", - "\u001b[39malias \u001b[0m\u001b[33mgp\u001b[0m\u001b[39m=\u001b[0m\u001b[32m'git pull'\u001b[0m\n", - "\u001b[39malias \u001b[0m\u001b[33mgpu\u001b[0m\u001b[39m=\u001b[0m\u001b[32m'git push'\u001b[0m\n", - "\u001b[39malias \u001b[0m\u001b[33mgs\u001b[0m\u001b[39m=\u001b[0m\u001b[32m'git status'\u001b[0m\n", - "\u001b[39malias \u001b[0m\u001b[33missue\u001b[0m\u001b[39m=\u001b[0m\u001b[32m'gh issue create'\u001b[0m\n", - "\u001b[39malias \u001b[0m\u001b[33missues\u001b[0m\u001b[39m=\u001b[0m\u001b[32m'gh issue list'\u001b[0m\n", - "\u001b[39malias 
\u001b[0m\u001b[33mjnb\u001b[0m\u001b[39m=\u001b[0m\u001b[32m'jupyter nbclassic'\u001b[0m\n", - "\u001b[39malias \u001b[0m\u001b[33mprep\u001b[0m\u001b[39m=\u001b[0m\u001b[32m'nbdev_export && nbdev_clean && nbdev_trust'\u001b[0m\n", - "\u001b[39malias \u001b[0m\u001b[33mrecent\u001b[0m\u001b[39m=\u001b[0m\u001b[32m'ls -lth | head -n 20'\u001b[0m\n", - "\u001b[39malias \u001b[0m\u001b[33mtb1\u001b[0m\u001b[39m=\u001b[0m\u001b[32m'ssh tb1'\u001b[0m\n", - "\u001b[39malias \u001b[0m\u001b[33mtb2\u001b[0m\u001b[39m=\u001b[0m\u001b[32m'ssh tb2'\u001b[0m\n", - "\u001b[39malias \u001b[0m\u001b[33mtopypi\u001b[0m\u001b[39m=\u001b[0m\u001b[32m'rm -rf dist/* && python -m build && twine upload dist/*'\u001b[0m\n", - "\u001b[39malias \u001b[0m\u001b[33mtunnel\u001b[0m\u001b[39m=\u001b[0m\u001b[32m'cloudflared tunnel --url http://localhost:5001'\u001b[0m\n", - "\u001b[39malias \u001b[0m\u001b[33mupi\u001b[0m\u001b[39m=\u001b[0m\u001b[32m'uv pip install'\u001b[0m\n", - "\u001b[39malias \u001b[0m\u001b[33mupie\u001b[0m\u001b[39m=\u001b[0m\u001b[32m'uv pip install --config-settings \u001b[0m\u001b[32meditable_mode\u001b[0m\u001b[32m=\u001b[0m\u001b[32mcompat\u001b[0m\u001b[32m -e'\u001b[0m\n", - "\u001b[39m<\u001b[0m\u001b[35m/\u001b[0m\u001b[95maliases\u001b[0m\u001b[39m>\u001b[0m\n", - "\u001b[39m<\u001b[0m\u001b[35m/\u001b[0m\u001b[95msystem_info\u001b[0m\u001b[1m>\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" + "name": "stderr", + "output_type": "stream", + "text": [ + "bash: no job control in this shell\n" + ] } ], "source": [ @@ -444,38 +379,10 @@ "execution_count": null, "id": "e799e039", "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n",
-       "Global options:\n",
-       "  -q, --quiet                                      Do not print any output\n",
-       "  -v, --verbose...                                 Use verbose output\n",
-       "      --color <COLOR_CHOICE>                       Control colors in output  \n",
-       "      --native-tls                                 Whether to load TLS certificates from the platform's native \n",
-       "certificate store \n",
-       "      --offline  \n",
-       "
\n" - ], - "text/plain": [ - "\n", - "Global options:\n", - " -q, --quiet Do not print any output\n", - " -v, --verbose\u001b[33m...\u001b[0m Use verbose output\n", - " --color \u001b[1m<\u001b[0m\u001b[1;95mCOLOR_CHOICE\u001b[0m\u001b[1m>\u001b[0m Control colors in output \n", - " --native-tls Whether to load TLS certificates from the platform's native \n", - "certificate store \n", - " --offline \n" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ - "p = get_pane(20)\n", - "print(p[:512])" + "# p = get_pane(20)\n", + "# print(p[:512])" ] }, { @@ -497,59 +404,10 @@ "execution_count": null, "id": "a537f788", "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
<pane id=%0 active>\n",
-       "Global options:\n",
-       "  -q, --quiet                                      Do not print any output\n",
-       "  -v, --verbose...                                 Use verbose output\n",
-       "      --color <COLOR_CHOICE>                       Control colors in output  \n",
-       "      --native-tls                                 Whether to load TLS certificates from the platform's native \n",
-       "certificate store [env:\n",
-       "                                                   UV_NATIVE_TLS=\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1m<\u001b[0m\u001b[1;95mpane\u001b[0m\u001b[39m \u001b[0m\u001b[33mid\u001b[0m\u001b[39m=%\u001b[0m\u001b[1;36m0\u001b[0m\u001b[39m active>\u001b[0m\n", - "\u001b[39mGlobal options:\u001b[0m\n", - "\u001b[39m -q, --quiet Do not print any output\u001b[0m\n", - "\u001b[39m -v, --verbose\u001b[0m\u001b[33m...\u001b[0m\u001b[39m Use verbose output\u001b[0m\n", - "\u001b[39m --color \u001b[0m Control colors in output \n", - " --native-tls Whether to load TLS certificates from the platform's native \n", - "certificate store \u001b[1m[\u001b[0menv:\n", - " \u001b[33mUV_NATIVE_TLS\u001b[0m=\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "ps = get_panes(20)\n", - "print(ps[:512])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "28e60780", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'2000'" - ] - }, - "execution_count": null, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ - "co(['tmux', 'display-message', '-p', '#{history-limit}'], text=True).strip()" + "# ps = get_panes(20)\n", + "# print(ps[:512])" ] }, { @@ -561,8 +419,7 @@ "source": [ "#| export\n", "def tmux_history_lim():\n", - " lim = co(['tmux', 'display-message', '-p', '#{history-limit}'], text=True).strip()\n", - " return int(lim)" + " return int(co(['tmux', 'display-message', '-p', '#{history-limit}'], text=True).strip())" ] }, { @@ -602,6 +459,102 @@ " except subprocess.CalledProcessError: return None" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "fb0b5827", + "metadata": {}, + "outputs": [], + "source": [ + "# hist = get_history(20)\n", + "# print(hist[:512])" + ] + }, + { + "cell_type": "markdown", + "id": "ca18af6a", + "metadata": {}, + "source": [ + "## RAG" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "51fec0d3", + "metadata": {}, + "outputs": [], + "source": [ + "#| export\n", + "def 
fmt_doc(r):\n", + " return f'\\n{r.content}\\n'\n", + "\n", + "def get_docs(q: str, limit: int=16, threshold: float=0.5):\n", + " df = tbl.search(q, limit, threshold)\n", + " docs = [fmt_doc(r) for r in df.itertuples()]\n", + " print(md(f'```\\nRetrieved the following man pages: {\", \".join(df[\"package_name\"])}\\n```'))\n", + " return f'\\n{\"\\n\".join(docs)}\\n'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9ddd3bba", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "NAME\n", + " rsync - faster, flexible replacement for rcp\n", + "\n", + "SYNOPSIS\n", + " rsync [OPTION]... SRC [SRC]... DEST\n", + "\n", + " rsync [OPTION]... SRC [SRC]... [USER@]HOST:DEST\n", + "\n", + " rsync [OPTION]... S\n" + ] + } + ], + "source": [ + "print(get_docs('How can I do backups with rsync?', limit=16, threshold=0.2)[:256])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e1f8ecfa", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "NAME\n", + " unzip - list, test and extract compressed files in a ZIP archive\n", + "\n", + "SYNOPSIS\n", + " unzip [-Z] [-cflptTuvz[abjnoqsCDKLMUVWX$/:^]] file[.zip] [file(s) ...]\n", + " [-x xfile(s) ...] [-d exdir]\n", + "\n", + "DESCRIPTION\n", + " unzip will list, test, or extract files from a ZIP archive, commonly\n", + " found on MS-DOS systems. 
The default behavior (with no options) is to\n", + " extract into the current directory (and subdirectories below it) all\n", + " \n" + ] + } + ], + "source": [ + "print(get_docs('How do I unzip a file?', limit=1)[:512])" + ] + }, { "cell_type": "markdown", "id": "fc100d13", @@ -631,26 +584,10 @@ "execution_count": null, "id": "366bbf4d", "metadata": {}, - "outputs": [ - { - "data": { - "text/markdown": [ - "```json\n", - "{'model': 'claude-3-5-sonnet-20241022', 'provider': 'anthropic'}\n", - "```" - ], - "text/plain": [ - "{'provider': 'anthropic', 'model': 'claude-3-5-sonnet-20241022'}" - ] - }, - "execution_count": null, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ - "opts = get_opts(provider=None, model=None)\n", - "opts" + "# opts = get_opts(provider=None, model=None)\n", + "# opts" ] }, { @@ -876,6 +813,8 @@ " history_lines: int = None, # Number of history lines. Defaults to tmux scrollback history length\n", " s: bool = False, # Enable sassy mode\n", " c: bool = False, # Enable command mode\n", + " use_retrieval: bool = False, # Retrieve relevant man pages and add them to the prompt. 
Must have installed `shell_sage[rag]`.\n", + " retrieve_limit: int = 5, # Number of documents to retrieve.\n", " provider: str = None, # The LLM Provider\n", " model: str = None, # The LLM model that will be invoked on the LLM provider\n", " base_url: str = None,\n", @@ -886,7 +825,8 @@ "):\n", " opts = get_opts(history_lines=history_lines, provider=provider, model=model,\n", " base_url=base_url, api_key=api_key, code_theme=code_theme,\n", - " code_lexer=code_lexer)\n", + " code_lexer=code_lexer, use_retrieval=use_retrieval,\n", + " retrieve_limit=retrieve_limit)\n", "\n", " mode = 'default'\n", " if s: mode = 'sassy'\n", @@ -899,7 +839,7 @@ " print(f\"{datetime.now()} | Starting ShellSage request with options {opts}\")\n", " md = partial(Markdown, code_theme=opts.code_theme, inline_code_lexer=opts.code_lexer, inline_code_theme=opts.code_theme)\n", " query = ' '.join(query)\n", - " ctxt = '' if skip_system else _sys_info()\n", + " ctxt = [''] if skip_system else [_sys_info()]\n", "\n", " # Get tmux history if in a tmux session\n", " if os.environ.get('TMUX'):\n", @@ -907,14 +847,20 @@ " if opts.history_lines is None or opts.history_lines < 0:\n", " opts.history_lines = tmux_history_lim()\n", " history = get_history(opts.history_lines,pid)\n", - " if history: ctxt += f'\\n{history}\\n'\n", + " if history: ctxt += [f'\\n{history}\\n']\n", "\n", + " if opts.use_retrieval:\n", + " if verbosity>0: print(f\"{datetime.now()} | Retrieving relevant man pages\")\n", + " try: ctxt += [get_docs(query, limit=opts.retrieve_limit)]\n", + " except ImportError:\n", + " raise Exception('Must have installed `shell_sage[rag]` to retrieve man pages')\n", " # Read from stdin if available\n", " if not sys.stdin.isatty(): \n", " if verbosity>0: print(f\"{datetime.now()} | Adding stdin to prompt\")\n", - " ctxt += f'\\n\\n{sys.stdin.read()}'\n", + " ctxt += [f'\\n{sys.stdin.read()}']\n", " \n", " if verbosity>0: print(f\"{datetime.now()} | Finalizing prompt\")\n", + " ctxt = 
'\\n'.join(ctxt)\n", " query = f'{ctxt}\\n\\n{query}\\n'\n", " query = [mk_msg(query)] if opts.provider == 'openai' else query\n", "\n", diff --git a/nbs/01_config.ipynb b/nbs/01_config.ipynb index 02037f1..9e217dd 100644 --- a/nbs/01_config.ipynb +++ b/nbs/01_config.ipynb @@ -148,6 +148,8 @@ " base_url: str = ''\n", " api_key: str = ''\n", " history_lines: int = -1\n", + " retrieve_limit: int = 5\n", + " use_retrieval: bool = True\n", " code_theme: str = \"monokai\"\n", " code_lexer: str = \"python\"" ] @@ -161,7 +163,7 @@ { "data": { "text/plain": [ - "ShellSageConfig(provider='anthropic', model='claude-3-5-sonnet-20241022', base_url='', api_key='', history_lines=-1, code_theme='monokai', code_lexer='python')" + "ShellSageConfig(provider='anthropic', model='claude-3-5-sonnet-20241022', base_url='', api_key='', history_lines=-1, retrieve_limit=5, code_theme='monokai', code_lexer='python')" ] }, "execution_count": null, @@ -195,30 +197,11 @@ "execution_count": null, "id": "efd3d92c", "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{'provider': 'anthropic', 'model': 'claude-3-5-sonnet-20241022', 'base_url': '', 'api_key': '', 'history_lines': '-1', 'code_theme': 'monokai', 'code_lexer': 'python'}" - ] - }, - "execution_count": null, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ - "cfg = get_cfg()\n", - "cfg" + "# cfg = get_cfg()\n", + "# cfg" ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d5483a5d", - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { diff --git a/nbs/02_rag.ipynb b/nbs/02_rag.ipynb new file mode 100644 index 0000000..6a1883a --- /dev/null +++ b/nbs/02_rag.ipynb @@ -0,0 +1,598 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "929f165e", + "metadata": {}, + "outputs": [], + "source": [ + "#|default_exp rag" + ] + }, + { + "cell_type": "markdown", + "id": "8d5eaf1e", + "metadata": {}, + "source": [ + "# ShellSage 
Retrieval Augmented Generation" + ] + }, + { + "cell_type": "markdown", + "id": "11feb5d9", + "metadata": {}, + "source": [ + "## Imports" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b2f9566d", + "metadata": {}, + "outputs": [], + "source": [ + "#| export\n", + "from chonkie import SentenceChunker\n", + "from fastcore.all import *\n", + "from fastprogress.fastprogress import progress_bar\n", + "from lancedb import connect\n", + "from lancedb.pydantic import LanceModel, Vector\n", + "from lancedb.rerankers import LinearCombinationReranker\n", + "from lancedb.table import LanceTable\n", + "from pathlib import Path\n", + "from sentence_transformers import SentenceTransformer\n", + "from subprocess import check_output as co\n", + "\n", + "import os, subprocess\n", + "os.environ['TOKENIZERS_PARALLELISM'] = 'false'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6e326ae3", + "metadata": {}, + "outputs": [], + "source": [ + "import random" + ] + }, + { + "cell_type": "markdown", + "id": "7ade9bcf", + "metadata": {}, + "source": [ + "## Database setup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a1ecb07e", + "metadata": {}, + "outputs": [], + "source": [ + "#| export\n", + "# set up db path in user's home cache directory\n", + "db_path = Path.home() / '.cache' / 'shell_sage' / 'db'\n", + "db_path.mkdir(parents=True, exist_ok=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b122b4c5", + "metadata": {}, + "outputs": [], + "source": [ + "#| export\n", + "chunker = SentenceChunker(tokenizer=\"gpt2\", chunk_size=2_048,\n", + " chunk_overlap=256, min_sentences_per_chunk=1)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "19f10a4f", + "metadata": {}, + "outputs": [], + "source": [ + "#| export\n", + "model = SentenceTransformer('thenlper/gte-small')\n", + "ndim = model.encode([\"Example sentence\"]).shape[1]" + ] + }, + { + "cell_type": "code", + 
"execution_count": null, + "id": "7ae76c15", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "384" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ndim" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6f6bd2a8", + "metadata": {}, + "outputs": [], + "source": [ + "#| export\n", + "class EmbeddingTable(LanceModel):\n", + " content: str\n", + " package_name: str\n", + " embedding: Vector(ndim)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7bb739fc", + "metadata": {}, + "outputs": [], + "source": [ + "#| export\n", + "db = connect(db_path)\n", + "tbl = None\n", + "try: tbl = db.open_table(\"man_pages\")\n", + "except ValueError:\n", + " tbl = db.create_table(\"man_pages\", schema=EmbeddingTable, mode=\"create\")\n", + " tbl.create_fts_index(\"content\") # for hybrid search" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ca74754e", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(#1) [Path('/Users/nathan/.cache/shell_sage/db/man_pages.lance')]" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "db_path.ls()" + ] + }, + { + "cell_type": "markdown", + "id": "91867f2c", + "metadata": {}, + "source": [ + "## Man pages" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "88a378cd", + "metadata": {}, + "outputs": [], + "source": [ + "#| export\n", + "def _section(cmd, section):\n", + " s = co(f'man {cmd} | col -b | sed -n \"/^{section}/,/^[A-Z]/p\" | sed \"$d\"',\n", + " shell=True, stderr=subprocess.DEVNULL, text=True).strip()\n", + " return '\\n'.join(s.splitlines()[:-1]).strip()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1fb0119e", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "SYNOPSIS\n", + " ls 
#| export
def _get_page(cmd):
    """Assemble the indexable text for one command.

    Fetches the NAME, SYNOPSIS, DESCRIPTION and EXAMPLES sections through
    `_section`, in that order, and joins them with blank lines.

    Returns:
        A `(cmd, text)` pair; `text` is '' when every section came back empty.
    """
    headings = ('NAME', 'SYNOPSIS', 'DESCRIPTION', 'EXAMPLES')
    parts = [_section(cmd, heading) for heading in headings]
    return cmd, '\n\n'.join(parts).strip()


#| export
def _manpages(lim=None):
    """Collect man-page text for the commands listed by `apropos -s 1 .`.

    Parameters:
        lim: keep only the first `lim` distinct command names
            (None keeps them all).

    Returns:
        `zip(*pages)` over the `(cmd, text)` pairs whose text is non-empty:
        an iterator that yields the command names and then the page texts.
        It yields nothing when no page has text.
    """
    listing = co(['apropos', '-s', '1', '.'], text=True).strip()
    # The command name is whatever precedes the first "(" on each line.
    names = L(listing.splitlines()).map(lambda entry: entry.split("(")[0].strip())
    cmds = names.filter(lambda name: name).unique()[:lim]
    fetched = parallel(_get_page, cmds, progress=progress_bar)
    non_empty = fetched.filter(lambda pair: pair[1])
    return zip(*non_empty)
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
contentpackage_nameembeddingscore
0\\t Note the quotes around *.c. The file hell...git-checkout[-0.041774247, -0.044317152, 0.0674934, 0.0194...1.000000
1NAME\\n git-checkout - Switch branches or...git-checkout[-0.048017204, -0.049874607, 0.05528505, -0.00...0.955922
2Thus you can, e.g., turn a library subd...git-filter-branch[-0.043320876, -0.013584119, 0.04771582, 0.008...0.642050
3not to fetch them again. See also the p...git-branch[-0.061778784, -0.031455133, 0.05729562, -0.01...0.505731
\n", + "
" + ], + "text/plain": [ + " content package_name \\\n", + "0 \\t Note the quotes around *.c. The file hell... git-checkout \n", + "1 NAME\\n git-checkout - Switch branches or... git-checkout \n", + "2 Thus you can, e.g., turn a library subd... git-filter-branch \n", + "3 not to fetch them again. See also the p... git-branch \n", + "\n", + " embedding score \n", + "0 [-0.041774247, -0.044317152, 0.0674934, 0.0194... 1.000000 \n", + "1 [-0.048017204, -0.049874607, 0.05528505, -0.00... 0.955922 \n", + "2 [-0.043320876, -0.013584119, 0.04771582, 0.008... 0.642050 \n", + "3 [-0.061778784, -0.031455133, 0.05729562, -0.01... 0.505731 " + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df = tbl.search('How can I change my current branch?', limit=8)\n", + "df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d9f5a489", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "NAME\n", + " git-checkout - Switch branches or restore working tree files\n", + "\n", + "SYNOPSIS\n", + " git checkout [-q] [-f] [-m] []\n", + " git checkout [-q] [-f] [-m] --detach []\n", + " git checkout [-q] [-f] [-m] [--detach] \n", + " git checkout [-q] [-f] [-m] [[-b|-B|--orphan] ] []\n", + " git checkout [-f|--ours|--theirs|-m|--conflict=