From 81631f4b9c3c91d0f5facd4194e30b497025b2b7 Mon Sep 17 00:00:00 2001 From: Kevin Lundell Date: Wed, 30 Apr 2025 18:32:43 +0200 Subject: [PATCH 1/6] Implement full GitHub integration - Add OAuth flow - Enable repository selection - Add entry synchronization; including resync logic - Support GitHub disconnection --- .../components/github/GitHubRepoSelector.tsx | 150 +++++++ src/interface/web/app/settings/page.tsx | 389 +++++++++++------- src/khoj/configure.py | 4 + .../content/github/github_to_entries.py | 43 ++ src/khoj/routers/api_github.py | 222 ++++++++++ 5 files changed, 652 insertions(+), 156 deletions(-) create mode 100644 src/interface/web/app/components/github/GitHubRepoSelector.tsx create mode 100644 src/khoj/routers/api_github.py diff --git a/src/interface/web/app/components/github/GitHubRepoSelector.tsx b/src/interface/web/app/components/github/GitHubRepoSelector.tsx new file mode 100644 index 000000000..876f4b169 --- /dev/null +++ b/src/interface/web/app/components/github/GitHubRepoSelector.tsx @@ -0,0 +1,150 @@ +import React, { useEffect, useRef, useState } from "react"; +import { Button } from "@/components/ui/button"; +import { Checkbox } from "@/components/ui/checkbox"; +import { Card, CardContent, CardHeader } from "@/components/ui/card"; +import { useToast } from "@/components/ui/use-toast"; +import { + Dialog, + DialogContent, + DialogHeader, + DialogTitle, + DialogTrigger, +} from "@/components/ui/dialog"; +import { Loader2, Files } from "lucide-react"; + +interface Repo { + name: string; + owner: string; + branch: string; + full_name: string; + description?: string; + private: boolean; + selected?: boolean; +} + +export default function GitHubRepoSelector({ openExternally }: { openExternally?: boolean }) { + const [repos, setRepos] = useState([]); + const [selected, setSelected] = useState>(new Set()); + const [open, setOpen] = useState(openExternally || false); + const [loading, setLoading] = useState(false); + const [wasAutoOpened, setWasAutoOpened] = useState(false); + const { toast } = useToast(); + // This is the time we wait for the backend to process the request and create the files for the user + + const hasLoaded = useRef(false); + + useEffect(() => { + if (hasLoaded.current) return; + hasLoaded.current = true; + + if (typeof window === "undefined") return; + + const params = new URLSearchParams(window.location.search); + const shouldOpen = openExternally || params.get("github_connected") === "true"; + + setLoading(true); + fetch("/api/github/repos") + .then((res) => res.json()) + .then((data) => { + setRepos(data); + setSelected(new Set(data.filter((r) => r.selected).map((r) => r.full_name))); + + if (shouldOpen) { + setOpen(true); + setWasAutoOpened(true); + params.delete("github_connected"); + const newUrl = `${window.location.pathname}?${params.toString()}`; + window.history.replaceState({}, "", newUrl); + } + }) + .catch(() => toast({ title: "⚠️ Failed to load GitHub repos" })) + .finally(() => setLoading(false)); + }, []); + + const toggleRepo = (fullName: string) => { + const next = new Set(selected); + if (next.has(fullName)) next.delete(fullName); + else next.add(fullName); + setSelected(next); + }; + + const submitSelection = async () => { + const selectedRepos = repos.filter((r) => selected.has(r.full_name)); + const payload = selectedRepos.map(({ name, owner, branch }) => ({ name, owner, branch })); + const res = await fetch("/api/github/repos/select", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ repos: payload }), + }); + + if (res.ok) { + toast({ + title: "✅ GitHub repos saved", + description: `Selected ${payload.length} repos.`, + }); + setOpen(false); + } else { + toast({ title: "❌ Failed to save GitHub repos" }); + } + }; + + return ( + { + if (!nextOpen && wasAutoOpened && selected.size === 0) { + const confirmed = window.confirm( + "You haven't selected any repositories. If you close this window, GitHub integration will remain inactive. Are you sure you want to continue?", + ); + if (!confirmed) return; + } + setOpen(nextOpen); + }} + > + + + + + + Select GitHub Repositories to Index + + + + {loading ? ( +
+ +
+ ) : ( + repos.map((repo) => ( +
+ toggleRepo(repo.full_name)} + /> +
+
{repo.full_name}
+
+ {repo.description || "No description"} +
+
+
+ )) + )} +
+ {!loading && ( +
+ +
+ )} +
+
+
+ ); +} diff --git a/src/interface/web/app/settings/page.tsx b/src/interface/web/app/settings/page.tsx index 5983c5913..a42c27fe1 100644 --- a/src/interface/web/app/settings/page.tsx +++ b/src/interface/web/app/settings/page.tsx @@ -23,8 +23,15 @@ import { DropdownMenuTrigger, } from "@/components/ui/dropdown-menu"; import { - AlertDialog, AlertDialogAction, AlertDialogCancel, - AlertDialogContent, AlertDialogDescription, AlertDialogFooter, AlertDialogHeader, AlertDialogTitle, AlertDialogTrigger + AlertDialog, + AlertDialogAction, + AlertDialogCancel, + AlertDialogContent, + AlertDialogDescription, + AlertDialogFooter, + AlertDialogHeader, + AlertDialogTitle, + AlertDialogTrigger, } from "@/components/ui/alert-dialog"; import { Table, TableBody, TableCell, TableRow } from "@/components/ui/table"; @@ -63,6 +70,7 @@ import { } from "@phosphor-icons/react"; import Loading from "../components/loading/loading"; +import GitHubRepoSelector from "../components/github/GitHubRepoSelector"; import IntlTelInput from "intl-tel-input/react"; import { SidebarInset, SidebarProvider, SidebarTrigger } from "@/components/ui/sidebar"; @@ -72,7 +80,8 @@ import { KhojLogoType } from "../components/logo/khojLogo"; import { Progress } from "@/components/ui/progress"; import JSZip from "jszip"; -import { saveAs } from 'file-saver'; +import { saveAs } from "file-saver"; +import { useRouter } from "next/router"; interface DropdownComponentProps { items: ModelOptions[]; @@ -81,7 +90,12 @@ interface DropdownComponentProps { callbackFunc: (value: string) => Promise; } -const DropdownComponent: React.FC = ({ items, selected, isActive, callbackFunc }) => { +const DropdownComponent: React.FC = ({ + items, + selected, + isActive, + callbackFunc, +}) => { const [position, setPosition] = useState(selected?.toString() ?? "0"); return ( @@ -114,7 +128,10 @@ const DropdownComponent: React.FC = ({ items, selected, value={item.id.toString()} disabled={!isActive && item.tier !== "free"} > - {item.name} {item.tier === "standard" && (Futurist)} + {item.name}{" "} + {item.tier === "standard" && ( + (Futurist) + )} ))} @@ -315,6 +332,8 @@ export default function SettingsView() { const { toast } = useToast(); const isMobileWidth = useIsMobileWidth(); + const [githubConnected, setGithubConnected] = useState(false); + const title = "Settings"; const cardClassName = @@ -327,11 +346,15 @@ export default function SettingsView() { initialUserConfig?.is_phone_number_verified ? PhoneNumberValidationState.Verified : initialUserConfig?.phone_number - ? PhoneNumberValidationState.SendOTP - : PhoneNumberValidationState.Setup, + ? PhoneNumberValidationState.SendOTP + : PhoneNumberValidationState.Setup, ); setName(initialUserConfig?.given_name); setNotionToken(initialUserConfig?.notion_token ?? null); + if (typeof window !== "undefined") { + const params = new URLSearchParams(window.location.search); + setGithubConnected(params.get("github_connected") === "true"); + } }, [initialUserConfig]); const sendOTP = async () => { @@ -524,13 +547,14 @@ export default function SettingsView() { const updateModel = (modelType: string) => async (id: string) => { // Get the selected model from the options - const modelOptions = modelType === "chat" - ? userConfig?.chat_model_options - : modelType === "paint" - ? userConfig?.paint_model_options - : userConfig?.voice_model_options; - - const selectedModel = modelOptions?.find(model => model.id.toString() === id); + const modelOptions = + modelType === "chat" + ? userConfig?.chat_model_options + : modelType === "paint" + ? userConfig?.paint_model_options + : userConfig?.voice_model_options; + + const selectedModel = modelOptions?.find((model) => model.id.toString() === id); const modelName = selectedModel?.name; // Check if the model is free tier or if the user is active @@ -551,7 +575,8 @@ export default function SettingsView() { }, }); - if (!response.ok) throw new Error(`Failed to switch ${modelType} model to ${modelName}`); + if (!response.ok) + throw new Error(`Failed to switch ${modelType} model to ${modelName}`); toast({ title: `✅ Switched ${modelType} model to ${modelName}`, @@ -570,7 +595,7 @@ export default function SettingsView() { setIsExporting(true); // Get total conversation count - const statsResponse = await fetch('/api/chat/stats'); + const statsResponse = await fetch("/api/chat/stats"); const stats = await statsResponse.json(); const total = stats.num_conversations; setTotalConversations(total); @@ -586,7 +611,7 @@ export default function SettingsView() { conversations.push(...data); setExportedConversations((page + 1) * 10); - setExportProgress(((page + 1) * 10 / total) * 100); + setExportProgress((((page + 1) * 10) / total) * 100); } // Add conversations to zip @@ -605,7 +630,7 @@ export default function SettingsView() { toast({ title: "Export Failed", description: "Failed to export chats. Please try again.", - variant: "destructive" + variant: "destructive", }); } finally { setIsExporting(false); @@ -693,6 +718,12 @@ export default function SettingsView() { setNotionToken(newUserConfig.notion_token); } else if (source === "github") { newUserConfig.enabled_content_source.github = false; + fetch("/api/github/disconnect", { + method: "DELETE", + headers: { + "Content-Type": "application/json", + }, + }); } setUserConfig(newUserConfig); } @@ -808,93 +839,93 @@ export default function SettingsView() { )) || (userConfig.subscription_state === "subscribed" && ( - <> -

- Futurist -

-

- Subscription renews on{" "} - - { - userConfig.subscription_renewal_date - } - -

- - )) || + <> +

+ Futurist +

+

+ Subscription renews on{" "} + + { + userConfig.subscription_renewal_date + } + +

+ + )) || (userConfig.subscription_state === "unsubscribed" && ( - <> -

Futurist

+ <> +

Futurist

+

+ Subscription ends on{" "} + + { + userConfig.subscription_renewal_date + } + +

+ + )) || + (userConfig.subscription_state === + "expired" && ( + <> +

Humanist

+ {(userConfig.subscription_renewal_date && (

- Subscription ends on{" "} + Subscription expired{" "} + on{" "} { userConfig.subscription_renewal_date }

- - )) || - (userConfig.subscription_state === - "expired" && ( - <> -

Humanist

- {(userConfig.subscription_renewal_date && ( -

- Subscription expired{" "} - on{" "} - - { - userConfig.subscription_renewal_date - } - -

- )) || ( -

- Check{" "} - - pricing page - {" "} - to compare plans. -

- )} - - ))} + )) || ( +

+ Check{" "} + + pricing page + {" "} + to compare plans. +

+ )} + + ))} {(userConfig.subscription_state == "subscribed" && ( + + )) || + (userConfig.subscription_state == + "unsubscribed" && ( )) || - (userConfig.subscription_state == - "unsubscribed" && ( - - )) || (userConfig.subscription_enabled_trial_at && ( + {githubConnected || + userConfig.enabled_content_source.github ? ( + + ) : ( + + )} - ) : /* Show set API key button notion oauth url not set setup */ - !userConfig.notion_oauth_url ? ( - - ) : ( - <> - ) + userConfig.enabled_content_source.notion && + notionToken === + userConfig.notion_token ? ( + + ) : /* Show set API key button notion oauth url not set setup */ + !userConfig.notion_oauth_url ? ( + + ) : ( + <> + ) } @@ -1245,7 +1298,11 @@ export default function SettingsView() {

- This will delete all your account data, including conversations, agents, and any assets you{"'"}ve generated. Be sure to export before you do this if you want to keep your information. + This will delete all your account data, + including conversations, agents, and any + assets you{"'"}ve generated. Be sure to + export before you do this if you want to + keep your information.

@@ -1261,36 +1318,56 @@ export default function SettingsView() { - Are you absolutely sure? + + Are you absolutely sure? + - This action is irreversible. This will permanently delete your account - and remove all your data from our servers. + This action is irreversible. + This will permanently delete + your account and remove all your + data from our servers. - Cancel + + Cancel + { try { - const response = await fetch('/api/self', { - method: 'DELETE' - }); - if (!response.ok) throw new Error('Failed to delete account'); + const response = + await fetch( + "/api/self", + { + method: "DELETE", + }, + ); + if (!response.ok) + throw new Error( + "Failed to delete account", + ); toast({ title: "Account Deleted", - description: "Your account has been successfully deleted.", + description: + "Your account has been successfully deleted.", }); // Redirect to home page after successful deletion - window.location.href = "/"; + window.location.href = + "/"; } catch (error) { - console.error('Error deleting account:', error); + console.error( + "Error deleting account:", + error, + ); toast({ title: "Error", - description: "Failed to delete account. Please try again or contact support.", - variant: "destructive" + description: + "Failed to delete account. Please try again or contact support.", + variant: + "destructive", }); } }} diff --git a/src/khoj/configure.py b/src/khoj/configure.py index 40d61a888..195ce860d 100644 --- a/src/khoj/configure.py +++ b/src/khoj/configure.py @@ -314,6 +314,7 @@ def configure_routes(app): from khoj.routers.api_agents import api_agents from khoj.routers.api_chat import api_chat from khoj.routers.api_content import api_content + from khoj.routers.api_github import github_router from khoj.routers.api_model import api_model from khoj.routers.notion import notion_router from khoj.routers.web_client import web_client @@ -323,9 +324,12 @@ def configure_routes(app): app.include_router(api_agents, prefix="/api/agents") app.include_router(api_model, prefix="/api/model") app.include_router(api_content, prefix="/api/content") + app.include_router(github_router, prefix="/api/github") app.include_router(notion_router, prefix="/api/notion") app.include_router(web_client) + logger.info("🛣️ API Routes configured") + if not state.anonymous_mode: from khoj.routers.auth import auth_router diff --git a/src/khoj/processor/content/github/github_to_entries.py b/src/khoj/processor/content/github/github_to_entries.py index 31f99f844..3fcf28b10 100644 --- a/src/khoj/processor/content/github/github_to_entries.py +++ b/src/khoj/processor/content/github/github_to_entries.py @@ -1,10 +1,12 @@ import logging +import re import time from typing import Dict, List, Tuple import requests from magika import Magika +from khoj.database.adapters import EntryAdapters from khoj.database.models import Entry as DbEntry from khoj.database.models import GithubConfig, KhojUser from khoj.processor.content.markdown.markdown_to_entries import MarkdownToEntries @@ -54,6 +56,10 @@ def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = Fals logger.warning( f"Github PAT token is not set. Private repositories cannot be indexed and lower rate limits apply." ) + + if user: + self.resync_github_entries(user) + current_entries = [] for repo in self.config.repos: current_entries += self.process_repo(repo) @@ -113,6 +119,43 @@ def update_entries_with_ids(self, current_entries, user: KhojUser = None): return num_new_embeddings, num_deleted_embeddings + def resync_github_entries(self, user: KhojUser = None) -> None: + """ + Resync GitHub entries for the user. + + This ensures that if a user deselects a repo, its files are no longer indexed. + Does not add or update entries — call `process()` separately for full re-index. + """ + + config = GithubConfig.objects.filter(user=user).prefetch_related("githubrepoconfig").first() + if config: + # Fetch all GitHub Entries for the user + files = EntryAdapters.get_all_filenames_by_source(user, "github") + raw_repos = config.githubrepoconfig.all() + repos = [] + for repo in raw_repos: + repos.append(repo.owner + "/" + repo.name) + + if files: + # Check if the entries' repository is still selected in the config + for file in files: + # We need to extract the repo name and owner from the entry's file path + # https://{url}/{owner}/{name}}/blob/... + match = re.search(r"github\.com/([^/]+)/([^/]+)", file) + if not match: + logger.warning(f"Unable to parse repo from file path: {file}") + continue + + owner = match.group(1) + name = match.group(2) + # Construct the repo name + repo_name = f"{owner}/{name}" + + if repo_name and repo_name not in repos: + # If not, delete the entry + logger.debug(f"Deleting entry {file} as the repo {repo_name} is not selected anymore") + EntryAdapters.delete_entry_by_file(user, file) + def get_files(self, repo_url: str, repo: GithubRepoConfig): # Get the contents of the repository repo_content_url = f"{repo_url}/git/trees/{repo.branch}" diff --git a/src/khoj/routers/api_github.py b/src/khoj/routers/api_github.py new file mode 100644 index 000000000..e9c3ce5ac --- /dev/null +++ b/src/khoj/routers/api_github.py @@ -0,0 +1,222 @@ +import logging +import os +import secrets +from typing import Optional + +import httpx +from fastapi import APIRouter, BackgroundTasks, Request +from fastapi.responses import JSONResponse, RedirectResponse +from starlette.authentication import requires + +from khoj.database import adapters +from khoj.database.models import GithubConfig, GithubRepoConfig, KhojUser +from khoj.processor.content.github.github_to_entries import GithubToEntries + +github_router = APIRouter() +logger = logging.getLogger(__name__) + +# Replace these with your GitHub OAuth app credentials +GITHUB_CLIENT_ID = os.getenv("GITHUB_CLIENT_ID") +GITHUB_CLIENT_SECRET = os.getenv("GITHUB_CLIENT_SECRET") +GITHUB_REDIRECT_URI = os.getenv("GITHUB_REDIRECT_URI") + +# In-memory store for testing (use a database in production) +oauth_state_store = {} + + +def save_oauth_state(state: str, user: KhojUser) -> None: + oauth_state_store[state] = user # Store the state and user mapping + + +def get_user_id_by_oauth_state(state: str) -> Optional[KhojUser]: + return oauth_state_store.pop(state, None) # Remove the state after use + + +def index_github(user: KhojUser): + config = GithubConfig.objects.filter(user=user).first() + if config: + GithubToEntries(config).process(files={}, user=user, regenerate=False) + logger.info(f"Github entries indexed for user {user.id}") + + +@github_router.get("/connect") +@requires(["authenticated"]) +async def connect_github(request: Request): + """ + Redirect the user to GitHub's OAuth authorization page. + """ + user = request.user + if not user.is_authenticated: + return JSONResponse(content={"error": "User not authenticated"}, status_code=401) + + # Generate a unique state value + state = secrets.token_urlsafe(16) + + # Save the state and user ID mapping (e.g., in a database or in-memory store) + save_oauth_state(state, user) # Implement this function to store the mapping + + github_oauth_url = ( + f"https://github.com/login/oauth/authorize" + f"?client_id={GITHUB_CLIENT_ID}&redirect_uri={GITHUB_REDIRECT_URI}&scope=repo,user" + f"&state={state}" + ) + return RedirectResponse(url=github_oauth_url) + + +@github_router.get("/callback") +async def github_callback(request: Request): + code = request.query_params.get("code") + state = request.query_params.get("state") + + if not code or not state: + logger.error("Missing code or state in GitHub callback") + return RedirectResponse(url="/settings") + + user = get_user_id_by_oauth_state(state) + if not user: + logger.error("Invalid or expired OAuth state") + return RedirectResponse(url="/settings") + + if not user or not hasattr(user, "object"): + logger.error("OAuth state returned invalid user") + return RedirectResponse(url="/settings") + + try: + async with httpx.AsyncClient() as client: + response = await client.post( + "https://github.com/login/oauth/access_token", + headers={"Accept": "application/json"}, + data={ + "client_id": GITHUB_CLIENT_ID, + "client_secret": GITHUB_CLIENT_SECRET, + "code": code, + "redirect_uri": GITHUB_REDIRECT_URI, + "state": state, + }, + ) + + if response.status_code != 200: + logger.error(f"GitHub token exchange failed: {response.text}") + return RedirectResponse(url="/settings") + + token_data = response.json() + access_token = token_data.get("access_token") + if not access_token: + logger.error("No access token returned from GitHub") + return RedirectResponse(url="/settings") + + except Exception as e: + logger.exception("Exception during GitHub token exchange") + return RedirectResponse(url="/settings") + + try: + # Save the GitHub access token + config = await adapters.GithubConfig.objects.filter(user=user.object).afirst() + if not config: + config = await adapters.GithubConfig.objects.acreate(pat_token=access_token, user=user.object) + else: + config.pat_token = access_token + await config.asave() + await config.githubrepoconfig.all().adelete() + + logger.info(f"GitHub integration successfully set up for user {user.object.id}") + settings_redirect = str(request.app.url_path_for("config_page")) + + logger.info(f"Redirecting to Settings config page: {settings_redirect}") + + return RedirectResponse(settings_redirect + "?github_connected=true") + + except Exception as e: + logger.exception("Failed to save GitHub configuration") + return RedirectResponse(url="/settings") + + +@github_router.get("/repos") +@requires(["authenticated"]) +async def list_user_repos(request: Request): + user = request.user + if not user.is_authenticated: + return JSONResponse({"error": "Not authenticated"}, status_code=401) + + config = await GithubConfig.objects.filter(user=user.object).prefetch_related("githubrepoconfig").afirst() + if not config: + return JSONResponse({"error": "GitHub not connected"}, status_code=400) + + logger.debug(f"GitHub config for user {user.object.id}: config: {config.id}") + + raw_repos = config.githubrepoconfig.all() + selected_repos = [] + for repo in raw_repos: + selected_repos.append(repo.owner + "/" + repo.name) + logger.debug(f"Repos from DB: {selected_repos}") + + headers = {"Authorization": f"token {config.pat_token}"} + async with httpx.AsyncClient() as client: + response = await client.get("https://api.github.com/user/repos", headers=headers) + + if response.status_code != 200: + return JSONResponse( + {"error": "Failed to fetch repos", "detail": response.text}, status_code=response.status_code + ) + + repos = response.json() + return [ + { + "name": r["name"], + "owner": r["owner"]["login"], + "branch": r["default_branch"], + "full_name": r["full_name"], + "description": r.get("description"), + "private": r.get("private", False), + "selected": r["full_name"] in selected_repos, # ✅ new flag + } + for r in repos + ] + + +@github_router.post("/repos/select") +@requires(["authenticated"]) +async def select_user_repos(request: Request, background_tasks: BackgroundTasks): + user = request.user + if not user.is_authenticated: + return JSONResponse({"error": "Not authenticated"}, status_code=401) + + body = await request.json() + repos = body.get("repos", []) + if not repos: + return JSONResponse({"error": "No repositories provided"}, status_code=400) + + config = await GithubConfig.objects.filter(user=user.object).afirst() + if not config: + return JSONResponse({"error": "GitHub not connected"}, status_code=400) + + await config.githubrepoconfig.all().adelete() # clear old selections + + for repo in repos: + await GithubRepoConfig.objects.acreate( + name=repo["name"], owner=repo["owner"], branch=repo["branch"], github_config=config + ) + + # Trigger an async job to index_github. Let it run without blocking the response. + background_tasks.add_task(index_github, user.object) + + return {"status": "success", "count": len(repos)} + + +@github_router.delete("/disconnect") +@requires(["authenticated"]) +async def disconnect_github(request: Request): + """ + Disconnect the GitHub integration for the authenticated user. + """ + user = request.user + if not user.is_authenticated: + return JSONResponse(content={"error": "User not authenticated"}, status_code=401) + + # Delete the GitHub configuration for the user + await GithubConfig.objects.filter(user=user.object).adelete() + + logger.info(f"GitHub integration successfully set up for user {user.object.id}") + settings_redirect = str(request.app.url_path_for("config_page")) + + return RedirectResponse(settings_redirect + "?github_connected=false") From 88301782b6bbca39393fef09642baa88756ce4bf Mon Sep 17 00:00:00 2001 From: Kevin Lundell Date: Wed, 30 Apr 2025 18:42:25 +0200 Subject: [PATCH 2/6] Refactor GitHubRepoSelector to initialize selected repos based on fetched data --- .../web/app/components/github/GitHubRepoSelector.tsx | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/interface/web/app/components/github/GitHubRepoSelector.tsx b/src/interface/web/app/components/github/GitHubRepoSelector.tsx index 876f4b169..5b981f713 100644 --- a/src/interface/web/app/components/github/GitHubRepoSelector.tsx +++ b/src/interface/web/app/components/github/GitHubRepoSelector.tsx @@ -47,7 +47,13 @@ export default function GitHubRepoSelector({ openExternally }: { openExternally? .then((res) => res.json()) .then((data) => { setRepos(data); - setSelected(new Set(data.filter((r) => r.selected).map((r) => r.full_name))); + setSelected( + new Set( + data + .filter((r: { selected: boolean }) => r.selected) + .map((r: { full_name: string }) => r.full_name), + ), + ); if (shouldOpen) { setOpen(true); From 76a66c052da4ae95f2b059fdd307405c5da89f5c Mon Sep 17 00:00:00 2001 From: Kevin Lundell Date: Tue, 6 May 2025 14:52:22 +0200 Subject: [PATCH 3/6] Add BackgroundServiceConfig model and ServiceManager for managing background tasks --- src/khoj/database/admin.py | 13 + src/khoj/database/models/__init__.py | 13 + src/khoj/main.py | 19 ++ .../processor/content/github/github_sync.py | 25 ++ src/khoj/utils/service_manager.py | 244 ++++++++++++++++++ 5 files changed, 314 insertions(+) create mode 100644 src/khoj/processor/content/github/github_sync.py create mode 100644 src/khoj/utils/service_manager.py diff --git a/src/khoj/database/admin.py b/src/khoj/database/admin.py index 7297ce118..a94ec1c72 100644 --- a/src/khoj/database/admin.py +++ b/src/khoj/database/admin.py @@ -17,6 +17,7 @@ from khoj.database.models import ( Agent, AiModelApi, + BackgroundServiceConfig, ChatModel, ClientApplication, Conversation, @@ -423,3 +424,15 @@ def get_subscription_type(self, obj): get_subscription_type.short_description = "Subscription Type" # type: ignore get_subscription_type.admin_order_field = "user__subscription__type" # type: ignore + + +@admin.register(BackgroundServiceConfig) +class BackgroundServiceConfigAdmin(admin.ModelAdmin): + list_display = ("task_name", "task_interval", "task_last_run", "task_next_run", "task_is_enabled") + list_editable = ("task_interval", "task_is_enabled") + list_filter = ("task_is_enabled",) + search_fields = ("task_name", "task_id") + ordering = ("task_name",) + + def has_add_permission(self, request): + return False # ❌ disables the "Add" button in the admin panel because we don't want to add new tasks manually diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py index bd49aa8cd..29c1aff40 100644 --- a/src/khoj/database/models/__init__.py +++ b/src/khoj/database/models/__init__.py @@ -365,6 +365,19 @@ class GithubConfig(DbBaseModel): user = models.ForeignKey(KhojUser, on_delete=models.CASCADE) +class BackgroundServiceConfig(models.Model): + class Meta: + verbose_name = "Background Service Config" + verbose_name_plural = "Background Service Configs" + + task_id = models.CharField(max_length=200) + task_name = models.CharField(max_length=200) + task_interval = models.IntegerField(default=3600) # interval in seconds (1 hour) + task_last_run = models.DateTimeField(null=True, default=None, blank=True) + task_next_run = models.DateTimeField(null=True, default=None, blank=True) + task_is_enabled = models.BooleanField(default=True) + + class GithubRepoConfig(DbBaseModel): name = models.CharField(max_length=200) owner = models.CharField(max_length=200) diff --git a/src/khoj/main.py b/src/khoj/main.py index 57794ebb0..256d1ce4f 100644 --- a/src/khoj/main.py +++ b/src/khoj/main.py @@ -16,6 +16,7 @@ from importlib.metadata import version from khoj.utils.helpers import in_debug_mode, is_env_var_true +from khoj.utils.service_manager import ServiceManager, BackgroundService # Ignore non-actionable warnings warnings.filterwarnings("ignore", message=r"snapshot_download.py has been made private", category=FutureWarning) @@ -101,6 +102,9 @@ SCHEDULE_LEADER_NAME = ProcessLock.Operation.SCHEDULE_LEADER +# Initialize Service Manager +service_manager = ServiceManager() + def shutdown_scheduler(): logger.info("🌑 Shutting down Khoj") @@ -109,12 +113,27 @@ def shutdown_scheduler(): logger.info("🔓 Schedule Leader released") ProcessLockAdapters.remove_process_lock(state.schedule_leader_process_lock) + # Stop all background services + service_manager.stop_all() + try: state.scheduler.shutdown() except Exception as e: logger.debug(f"Did not shutdown scheduler: {e}") +# Initialize Background Services +@app.on_event("startup") +def initialize_background_services(): + from khoj.processor.content.github.github_sync import github_sync_task + + # Register the GitHub sync task + service_manager.register_service(name="github_sync", interval=3600, fn=github_sync_task) # 1 hour + + # Start the background services + service_manager.start_all() + + def run(should_start_server=True): # Turn Tokenizers Parallelism Off. App does not support it. os.environ["TOKENIZERS_PARALLELISM"] = "false" diff --git a/src/khoj/processor/content/github/github_sync.py b/src/khoj/processor/content/github/github_sync.py new file mode 100644 index 000000000..dd7b86271 --- /dev/null +++ b/src/khoj/processor/content/github/github_sync.py @@ -0,0 +1,25 @@ +import logging + +from khoj.database.models import GithubConfig +from khoj.processor.content.github.github_to_entries import GithubToEntries + +logger = logging.getLogger(__name__) + + +def github_sync_task(): + """ + This module contains the GitHub sync task that runs periodically to sync + GitHub repositories with the Khoj database. + If the task returns False, it will stop running. + If the task returns True, it will continue running. + """ + + logger.info("Running GitHub sync...") + + # Fetch all GitHub configurations + configs = GithubConfig.objects.all() + if configs: + for config in configs: + GithubToEntries(config).process(files={}, user=config.user, regenerate=False) + + return True diff --git a/src/khoj/utils/service_manager.py b/src/khoj/utils/service_manager.py new file mode 100644 index 000000000..7e8a24760 --- /dev/null +++ b/src/khoj/utils/service_manager.py @@ -0,0 +1,244 @@ +import asyncio +import logging +import time +from typing import Callable, Dict, Optional + +from asgiref.sync import sync_to_async +from django.utils import timezone + +logger = logging.getLogger(__name__) + + +class BackgroundService: + + """ + A class to manage background services that run periodically. + Each service can be started, stopped, and registered with a name and interval. + """ + + def __init__(self, name: str, interval: int, fn: Callable, run_immediately: bool = True): + self.name = name + self.interval = interval + self.fn = fn + self.run_immediately = run_immediately + self.task: Optional[asyncio.Task] = None + self.should_continue = True + + from khoj.database.models import BackgroundServiceConfig + + self.config: BackgroundServiceConfig = None + + async def loadConfig(self) -> None: + """ + Load the configuration for the service from the database. + This is a placeholder for loading any necessary configuration. + """ + from khoj.database.models import BackgroundServiceConfig + + try: + # Remove old config + if self.config: + self.config = None + + self.config = await sync_to_async(BackgroundServiceConfig.objects.get)(task_name=self.name) + + self.should_continue = self.config.task_is_enabled + self.interval = self.config.task_interval + + if not self.config.task_is_enabled: + logger.info(f"[{self.name}] Service is disabled.") + self.run_immediately = False + self.should_continue = False + return + + will_save = False + + # Time checks + current_time = timezone.now().timestamp() + + if self.config.task_last_run and (current_time - self.config.task_last_run.timestamp()) < self.interval: + self.run_immediately = True + else: + will_save = True + + if self.config.task_next_run: + diff = self.config.task_next_run.timestamp() - current_time + if diff > 0: + self.run_immediately = False + + self.interval = diff + + else: + self.run_immediately = True + else: + self.config.task_next_run = timezone.now() + timezone.timedelta(seconds=self.interval) + will_save = True + + if will_save: + await sync_to_async(self.config.save)() + + logger.info(f"[{self.name}] Loaded config: {self.config}") + except BackgroundServiceConfig.DoesNotExist: + logger.warning(f"[{self.name}] No config found. Creating default entry.") + await sync_to_async(BackgroundServiceConfig.objects.create)( + task_id=f"default-{self.name}", + task_name=self.name, + task_interval=self.interval, + task_last_run=timezone.now(), + task_next_run=timezone.now() + timezone.timedelta(seconds=self.interval), + task_is_enabled=True, + ) + except Exception as e: + logger.error(f"[{self.name}] Error loading config: {e}") + self.should_continue = False + self.run_immediately = False + logger.info(f"[{self.name}] Service is disabled due to error.") + + async def updateConfig(self) -> bool: + """ + Update the configuration for the service in the database. + This is a placeholder for updating any necessary configuration. + """ + from khoj.database.models import BackgroundServiceConfig + + result = True + try: + self.config = await sync_to_async(BackgroundServiceConfig.objects.get)(task_name=self.name) + + self.config.task_last_run = timezone.now() + self.config.task_next_run = timezone.now() + timezone.timedelta(seconds=self.interval) + self.interval = self.config.task_interval + + if not self.config.task_is_enabled: + result = False + + await sync_to_async(self.config.save)() + except Exception as e: + logger.error(f"[{self.name}] Error updating config: {e}") + + return result + + async def _safe_call(self) -> bool: + try: + # Check if the service function is async or sync + if asyncio.iscoroutinefunction(self.fn): + result = await self.fn() + else: + result = await sync_to_async(self.fn)() + + result = bool(result) # Treat None or False as stop signal + + if result: + # if the function returns True, the configuration is updated + # if the function returns False, the service was disabled + # and the configuration is updated + if not await self.updateConfig(): + result = False + + return result + except Exception as e: + logger.error(f"[{self.name}] Error: {e}") + return True # Keep trying even after failure, unless you want to halt + + async def start(self): + if not self.config: + await self.loadConfig() + + if self.run_immediately: + self.should_continue = await self._safe_call() + + while self.should_continue: + await asyncio.sleep(self.interval) + if not self.should_continue: + break + self.should_continue = await self._safe_call() + + logger.info(f"[{self.name}] Stopped.") + if self.task: + self.task.cancel() + self.task = None + logger.info(f"[{self.name}] Task cancelled.") + + async def stop(self): + if self.task: + self.task.cancel() + self.should_continue = False + logger.info(f"[{self.name}] Stopped.") + self.task = None + + +class ServiceManager: + def __init__(self): + self.services: Dict[str, BackgroundService] = {} + + self.register_service( + name="service_manager", interval=60, fn=self._service_manager_task, run_immediately=False # 1 minute + ) + + async def _service_manager_task(self) -> bool: + """ + This task runs every minute to check the status of all registered services. + It can be used to perform any necessary maintenance or updates. + """ + + logger.info("Service Manager Task running...") + for service in self.services.values(): + if service.name == "service_manager": + continue + + await service.loadConfig() + + if service.task and not service.task.done(): + if service.config and service.config.task_is_enabled: + logger.info(f"Service '{service.name}' is running and should be.") + elif service.config and not service.config.task_is_enabled: + logger.info(f"Service '{service.name}' is running, but should not be.") + await service.stop() + else: + logger.info(f"Service '{service.name}' is running, but no config found.") + + elif service.config and service.config.task_is_enabled and not service.task: + logger.info(f"Service '{service.name}' is not running, but should be.") + + service.task = asyncio.create_task(service.start()) + elif service.config and not service.config.task_is_enabled and not service.task: + logger.info(f"Service '{service.name}' is not running and should not be.") + # service.config = None + + return True + + def register_service( + self, name: str, interval: int, fn: Callable, run_immediately: bool = True + ) -> BackgroundService: + if name in self.services: + raise ValueError(f"Service '{name}' already registered.") + service = BackgroundService(name, interval, fn, run_immediately) + self.services[name] = service + + return service + + def get_service(self, name: str) -> Optional[BackgroundService]: + return self.services.get(name) + + def start_all(self) -> None: + for service in self.services.values(): + service.task = asyncio.create_task(service.start()) + + def delete_service(self, name: str) -> Optional[BackgroundService]: + if name in self.services: + service = self.services.pop(name) + if service.task: + service.task.cancel() + logger.info(f"Service '{name}' stopped.") + return service + else: + logger.warning(f"Service '{name}' not found.") + return None + + def stop_all(self) -> None: + for service in self.services.values(): + if service.task: + service.task.cancel() + logger.info(f"Service '{service.name}' stopped.") + self.services.clear() + logger.info("All services stopped and cleared.") From 56cfe5ebf729f2ea6888e19812e0315f0f1b0976 Mon Sep 17 00:00:00 2001 From: Kevin Lundell Date: Tue, 6 May 2025 14:52:36 +0200 Subject: [PATCH 4/6] Enhance GithubToEntries with retry logic for file downloads and add regenerate option to update_entries_with_ids --- .../content/github/github_to_entries.py | 50 ++++++++++++------- 1 file changed, 33 insertions(+), 17 deletions(-) diff --git a/src/khoj/processor/content/github/github_to_entries.py b/src/khoj/processor/content/github/github_to_entries.py index 3fcf28b10..dadab0abe 100644 --- a/src/khoj/processor/content/github/github_to_entries.py +++ b/src/khoj/processor/content/github/github_to_entries.py @@ -64,7 +64,7 @@ def process(self, files: dict[str, str], user: KhojUser, regenerate: bool = Fals for repo in self.config.repos: current_entries += self.process_repo(repo) - return self.update_entries_with_ids(current_entries, user=user) + return self.update_entries_with_ids(current_entries, user=user, regenerate=regenerate) def process_repo(self, repo: GithubRepoConfig): repo_url = f"https://api.github.com/repos/{repo.owner}/{repo.name}" @@ -105,7 +105,7 @@ def process_repo(self, repo: GithubRepoConfig): return current_entries - def update_entries_with_ids(self, current_entries, user: KhojUser = None): + def update_entries_with_ids(self, current_entries, user: KhojUser = None, regenerate: bool = False): # Identify, mark and merge any new entries with previous entries with timer("Identify new or updated entries", logger): num_new_embeddings, num_deleted_embeddings = self.update_embeddings( @@ -115,6 +115,7 @@ def update_entries_with_ids(self, current_entries, user: KhojUser = None): DbEntry.EntrySource.GITHUB, key="compiled", logger=logger, + regenerate=regenerate, ) return num_new_embeddings, num_deleted_embeddings @@ -219,22 +220,37 @@ def get_files(self, repo_url: str, repo: GithubRepoConfig): def get_file_contents(self, file_url, decode=True): # Get text from each markdown file headers = {"Accept": "application/vnd.github.v3.raw"} - response = self.session.get(file_url, headers=headers, stream=True) - # Stop indexing on hitting rate limit - if response.status_code != 200 and response.headers.get("X-RateLimit-Remaining") == "0": - raise ConnectionAbortedError("Github rate limit reached") - - content = "" if decode else b"" - for chunk in response.iter_content(chunk_size=2048): - if chunk: - try: - content += chunk.decode("utf-8") if decode else chunk - except Exception as e: - logger.error(f"Unable to decode chunk from {file_url}") - logger.error(e) - - return content + for attempt in range(3): + try: + # Retry on rate limit + if attempt > 2: + logger.error(f"Unable to download file {file_url} after 3 attempts") + break + + response = self.session.get(file_url, headers=headers, stream=True) + + # Stop indexing on hitting rate limit + if response.status_code != 200 and response.headers.get("X-RateLimit-Remaining") == "0": + raise ConnectionAbortedError("Github rate limit reached") + + content = "" if decode else b"" + for chunk in response.iter_content(chunk_size=2048): + if chunk: + try: + content += chunk.decode("utf-8") if decode else chunk + except Exception as e: + logger.error(f"Unable to decode chunk from {file_url}") + logger.error(e) + + return content + except requests.exceptions.ChunkedEncodingError as e: + logger.error(f"Chunked encoding error while downloading {file_url}. Retrying...") + # Retry on chunked encoding error with exponential backoff approach + time.sleep(2**attempt) + + logger.error(f"Failed to download file {file_url} after 3 attempts") + return "" if decode else b"" @staticmethod def extract_markdown_entries(markdown_files): From 2a61ceb206e73b4aa1c31a10bf71d9d81cfcb5ae Mon Sep 17 00:00:00 2001 From: Kevin Lundell Date: Tue, 6 May 2025 14:52:42 +0200 Subject: [PATCH 5/6] Update export script in package.json to clean output directory before building --- src/interface/web/package.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/interface/web/package.json b/src/interface/web/package.json index ece5f5206..67606b091 100644 --- a/src/interface/web/package.json +++ b/src/interface/web/package.json @@ -9,7 +9,7 @@ "lint": "next lint", "collectstatic": "bash -c 'pushd ../../../ && source .venv/bin/activate && python3 src/khoj/manage.py collectstatic --noinput && deactivate && popd'", "cicollectstatic": "bash -c 'pushd ../../../ && python3 src/khoj/manage.py collectstatic --noinput && popd'", - "export": "yarn build && cp -r out/ ../../khoj/interface/built && yarn collectstatic", + "export": "rm -rf out/ ../../khoj/interface/built && yarn build && cp -r out/ ../../khoj/interface/built && yarn collectstatic", "ciexport": "yarn build && cp -r out/ ../../khoj/interface/built && yarn cicollectstatic", "pypiciexport": "yarn build && cp -r out/ /opt/hostedtoolcache/Python/3.11.12/x64/lib/python3.11/site-packages/khoj/interface/compiled && yarn cicollectstatic", "watch": "nodemon --watch . --ext js,jsx,ts,tsx,css --ignore 'out/**/*' --exec 'yarn export'", From 6d4fc24b2f50b28c372afd2f46dd227201b8a92d Mon Sep 17 00:00:00 2001 From: Kevin Lundell Date: Sat, 10 May 2025 19:11:50 +0200 Subject: [PATCH 6/6] Removed unrelated things to GitHub integration --- src/interface/web/package.json | 2 +- src/khoj/configure.py | 2 - src/khoj/database/admin.py | 13 - src/khoj/database/models/__init__.py | 13 - src/khoj/main.py | 19 -- .../processor/content/github/github_sync.py | 25 -- src/khoj/utils/service_manager.py | 244 ------------------ 7 files changed, 1 insertion(+), 317 deletions(-) delete mode 100644 src/khoj/processor/content/github/github_sync.py delete mode 100644 src/khoj/utils/service_manager.py diff --git a/src/interface/web/package.json b/src/interface/web/package.json index 67606b091..ece5f5206 100644 --- a/src/interface/web/package.json +++ b/src/interface/web/package.json @@ -9,7 +9,7 @@ "lint": "next lint", "collectstatic": "bash -c 'pushd ../../../ && source .venv/bin/activate && python3 src/khoj/manage.py collectstatic --noinput && deactivate && popd'", "cicollectstatic": "bash -c 'pushd ../../../ && python3 src/khoj/manage.py collectstatic --noinput && popd'", - "export": "rm -rf out/ ../../khoj/interface/built && yarn build && cp -r out/ ../../khoj/interface/built && yarn collectstatic", + "export": "yarn build && cp -r out/ ../../khoj/interface/built && yarn collectstatic", "ciexport": "yarn build && cp -r out/ ../../khoj/interface/built && yarn cicollectstatic", "pypiciexport": "yarn build && cp -r out/ /opt/hostedtoolcache/Python/3.11.12/x64/lib/python3.11/site-packages/khoj/interface/compiled && yarn cicollectstatic", "watch": "nodemon --watch . --ext js,jsx,ts,tsx,css --ignore 'out/**/*' --exec 'yarn export'", diff --git a/src/khoj/configure.py b/src/khoj/configure.py index 195ce860d..647a9fd25 100644 --- a/src/khoj/configure.py +++ b/src/khoj/configure.py @@ -328,8 +328,6 @@ def configure_routes(app): app.include_router(notion_router, prefix="/api/notion") app.include_router(web_client) - logger.info("🛣️ API Routes configured") - if not state.anonymous_mode: from khoj.routers.auth import auth_router diff --git a/src/khoj/database/admin.py b/src/khoj/database/admin.py index a94ec1c72..7297ce118 100644 --- a/src/khoj/database/admin.py +++ b/src/khoj/database/admin.py @@ -17,7 +17,6 @@ from khoj.database.models import ( Agent, AiModelApi, - BackgroundServiceConfig, ChatModel, ClientApplication, Conversation, @@ -424,15 +423,3 @@ def get_subscription_type(self, obj): get_subscription_type.short_description = "Subscription Type" # type: ignore get_subscription_type.admin_order_field = "user__subscription__type" # type: ignore - - -@admin.register(BackgroundServiceConfig) -class BackgroundServiceConfigAdmin(admin.ModelAdmin): - list_display = ("task_name", "task_interval", "task_last_run", "task_next_run", "task_is_enabled") - list_editable = ("task_interval", "task_is_enabled") - list_filter = ("task_is_enabled",) - search_fields = ("task_name", "task_id") - ordering = ("task_name",) - - def has_add_permission(self, request): - return False # ❌ disables the "Add" button in the admin panel because we don't want to add new tasks manually diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py index 29c1aff40..bd49aa8cd 100644 --- a/src/khoj/database/models/__init__.py +++ b/src/khoj/database/models/__init__.py @@ -365,19 +365,6 @@ class GithubConfig(DbBaseModel): user = models.ForeignKey(KhojUser, on_delete=models.CASCADE) -class BackgroundServiceConfig(models.Model): - class Meta: - verbose_name = "Background Service Config" - verbose_name_plural = "Background Service Configs" - - task_id = models.CharField(max_length=200) - task_name = models.CharField(max_length=200) - task_interval = models.IntegerField(default=3600) # interval in seconds (1 hour) - task_last_run = models.DateTimeField(null=True, default=None, blank=True) - task_next_run = models.DateTimeField(null=True, default=None, blank=True) - task_is_enabled = models.BooleanField(default=True) - - class GithubRepoConfig(DbBaseModel): name = models.CharField(max_length=200) owner = models.CharField(max_length=200) diff --git a/src/khoj/main.py b/src/khoj/main.py index 256d1ce4f..57794ebb0 100644 --- a/src/khoj/main.py +++ b/src/khoj/main.py @@ -16,7 +16,6 @@ from importlib.metadata import version from khoj.utils.helpers import in_debug_mode, is_env_var_true -from khoj.utils.service_manager import ServiceManager, BackgroundService # Ignore non-actionable warnings warnings.filterwarnings("ignore", message=r"snapshot_download.py has been made private", category=FutureWarning) @@ -102,9 +101,6 @@ SCHEDULE_LEADER_NAME = ProcessLock.Operation.SCHEDULE_LEADER -# Initialize Service Manager -service_manager = ServiceManager() - def shutdown_scheduler(): logger.info("🌑 Shutting down Khoj") @@ -113,27 +109,12 @@ def shutdown_scheduler(): logger.info("🔓 Schedule Leader released") ProcessLockAdapters.remove_process_lock(state.schedule_leader_process_lock) - # Stop all background services - service_manager.stop_all() - try: state.scheduler.shutdown() except Exception as e: logger.debug(f"Did not shutdown scheduler: {e}") -# Initialize Background Services -@app.on_event("startup") -def initialize_background_services(): - from khoj.processor.content.github.github_sync import github_sync_task - - # Register the GitHub sync task - service_manager.register_service(name="github_sync", interval=3600, fn=github_sync_task) # 1 hour - - # Start the background services - service_manager.start_all() - - def run(should_start_server=True): # Turn Tokenizers Parallelism Off. App does not support it. os.environ["TOKENIZERS_PARALLELISM"] = "false" diff --git a/src/khoj/processor/content/github/github_sync.py b/src/khoj/processor/content/github/github_sync.py deleted file mode 100644 index dd7b86271..000000000 --- a/src/khoj/processor/content/github/github_sync.py +++ /dev/null @@ -1,25 +0,0 @@ -import logging - -from khoj.database.models import GithubConfig -from khoj.processor.content.github.github_to_entries import GithubToEntries - -logger = logging.getLogger(__name__) - - -def github_sync_task(): - """ - This module contains the GitHub sync task that runs periodically to sync - GitHub repositories with the Khoj database. - If the task returns False, it will stop running. - If the task returns True, it will continue running. - """ - - logger.info("Running GitHub sync...") - - # Fetch all GitHub configurations - configs = GithubConfig.objects.all() - if configs: - for config in configs: - GithubToEntries(config).process(files={}, user=config.user, regenerate=False) - - return True diff --git a/src/khoj/utils/service_manager.py b/src/khoj/utils/service_manager.py deleted file mode 100644 index 7e8a24760..000000000 --- a/src/khoj/utils/service_manager.py +++ /dev/null @@ -1,244 +0,0 @@ -import asyncio -import logging -import time -from typing import Callable, Dict, Optional - -from asgiref.sync import sync_to_async -from django.utils import timezone - -logger = logging.getLogger(__name__) - - -class BackgroundService: - - """ - A class to manage background services that run periodically. - Each service can be started, stopped, and registered with a name and interval. - """ - - def __init__(self, name: str, interval: int, fn: Callable, run_immediately: bool = True): - self.name = name - self.interval = interval - self.fn = fn - self.run_immediately = run_immediately - self.task: Optional[asyncio.Task] = None - self.should_continue = True - - from khoj.database.models import BackgroundServiceConfig - - self.config: BackgroundServiceConfig = None - - async def loadConfig(self) -> None: - """ - Load the configuration for the service from the database. - This is a placeholder for loading any necessary configuration. - """ - from khoj.database.models import BackgroundServiceConfig - - try: - # Remove old config - if self.config: - self.config = None - - self.config = await sync_to_async(BackgroundServiceConfig.objects.get)(task_name=self.name) - - self.should_continue = self.config.task_is_enabled - self.interval = self.config.task_interval - - if not self.config.task_is_enabled: - logger.info(f"[{self.name}] Service is disabled.") - self.run_immediately = False - self.should_continue = False - return - - will_save = False - - # Time checks - current_time = timezone.now().timestamp() - - if self.config.task_last_run and (current_time - self.config.task_last_run.timestamp()) < self.interval: - self.run_immediately = True - else: - will_save = True - - if self.config.task_next_run: - diff = self.config.task_next_run.timestamp() - current_time - if diff > 0: - self.run_immediately = False - - self.interval = diff - - else: - self.run_immediately = True - else: - self.config.task_next_run = timezone.now() + timezone.timedelta(seconds=self.interval) - will_save = True - - if will_save: - await sync_to_async(self.config.save)() - - logger.info(f"[{self.name}] Loaded config: {self.config}") - except BackgroundServiceConfig.DoesNotExist: - logger.warning(f"[{self.name}] No config found. Creating default entry.") - await sync_to_async(BackgroundServiceConfig.objects.create)( - task_id=f"default-{self.name}", - task_name=self.name, - task_interval=self.interval, - task_last_run=timezone.now(), - task_next_run=timezone.now() + timezone.timedelta(seconds=self.interval), - task_is_enabled=True, - ) - except Exception as e: - logger.error(f"[{self.name}] Error loading config: {e}") - self.should_continue = False - self.run_immediately = False - logger.info(f"[{self.name}] Service is disabled due to error.") - - async def updateConfig(self) -> bool: - """ - Update the configuration for the service in the database. - This is a placeholder for updating any necessary configuration. - """ - from khoj.database.models import BackgroundServiceConfig - - result = True - try: - self.config = await sync_to_async(BackgroundServiceConfig.objects.get)(task_name=self.name) - - self.config.task_last_run = timezone.now() - self.config.task_next_run = timezone.now() + timezone.timedelta(seconds=self.interval) - self.interval = self.config.task_interval - - if not self.config.task_is_enabled: - result = False - - await sync_to_async(self.config.save)() - except Exception as e: - logger.error(f"[{self.name}] Error updating config: {e}") - - return result - - async def _safe_call(self) -> bool: - try: - # Check if the service function is async or sync - if asyncio.iscoroutinefunction(self.fn): - result = await self.fn() - else: - result = await sync_to_async(self.fn)() - - result = bool(result) # Treat None or False as stop signal - - if result: - # if the function returns True, the configuration is updated - # if the function returns False, the service was disabled - # and the configuration is updated - if not await self.updateConfig(): - result = False - - return result - except Exception as e: - logger.error(f"[{self.name}] Error: {e}") - return True # Keep trying even after failure, unless you want to halt - - async def start(self): - if not self.config: - await self.loadConfig() - - if self.run_immediately: - self.should_continue = await self._safe_call() - - while self.should_continue: - await asyncio.sleep(self.interval) - if not self.should_continue: - break - self.should_continue = await self._safe_call() - - logger.info(f"[{self.name}] Stopped.") - if self.task: - self.task.cancel() - self.task = None - logger.info(f"[{self.name}] Task cancelled.") - - async def stop(self): - if self.task: - self.task.cancel() - self.should_continue = False - logger.info(f"[{self.name}] Stopped.") - self.task = None - - -class ServiceManager: - def __init__(self): - self.services: Dict[str, BackgroundService] = {} - - self.register_service( - name="service_manager", interval=60, fn=self._service_manager_task, run_immediately=False # 1 minute - ) - - async def _service_manager_task(self) -> bool: - """ - This task runs every minute to check the status of all registered services. - It can be used to perform any necessary maintenance or updates. - """ - - logger.info("Service Manager Task running...") - for service in self.services.values(): - if service.name == "service_manager": - continue - - await service.loadConfig() - - if service.task and not service.task.done(): - if service.config and service.config.task_is_enabled: - logger.info(f"Service '{service.name}' is running and should be.") - elif service.config and not service.config.task_is_enabled: - logger.info(f"Service '{service.name}' is running, but should not be.") - await service.stop() - else: - logger.info(f"Service '{service.name}' is running, but no config found.") - - elif service.config and service.config.task_is_enabled and not service.task: - logger.info(f"Service '{service.name}' is not running, but should be.") - - service.task = asyncio.create_task(service.start()) - elif service.config and not service.config.task_is_enabled and not service.task: - logger.info(f"Service '{service.name}' is not running and should not be.") - # service.config = None - - return True - - def register_service( - self, name: str, interval: int, fn: Callable, run_immediately: bool = True - ) -> BackgroundService: - if name in self.services: - raise ValueError(f"Service '{name}' already registered.") - service = BackgroundService(name, interval, fn, run_immediately) - self.services[name] = service - - return service - - def get_service(self, name: str) -> Optional[BackgroundService]: - return self.services.get(name) - - def start_all(self) -> None: - for service in self.services.values(): - service.task = asyncio.create_task(service.start()) - - def delete_service(self, name: str) -> Optional[BackgroundService]: - if name in self.services: - service = self.services.pop(name) - if service.task: - service.task.cancel() - logger.info(f"Service '{name}' stopped.") - return service - else: - logger.warning(f"Service '{name}' not found.") - return None - - def stop_all(self) -> None: - for service in self.services.values(): - if service.task: - service.task.cancel() - logger.info(f"Service '{service.name}' stopped.") - self.services.clear() - logger.info("All services stopped and cleared.")