diff --git a/ui/__tests__/rag-documents.test.ts b/ui/__tests__/rag-documents.test.ts new file mode 100644 index 0000000..efd29f7 --- /dev/null +++ b/ui/__tests__/rag-documents.test.ts @@ -0,0 +1,32 @@ +import { + DEFAULT_CHROMA_PATH, + buildInternalApiUrl, + getChromaPath, +} from '@/utils/server/rag-documents'; + +import { afterEach, describe, expect, it } from 'vitest'; + +describe('rag document helpers', () => { + afterEach(() => { + delete process.env.CHROMA_PATH; + }); + + it('uses the Docker Chroma default when no override is configured', () => { + expect(getChromaPath()).toBe(DEFAULT_CHROMA_PATH); + }); + + it('uses CHROMA_PATH when configured', () => { + process.env.CHROMA_PATH = 'http://custom-chroma:8000'; + + expect(getChromaPath()).toBe('http://custom-chroma:8000'); + }); + + it('builds internal API URLs from the current request origin', () => { + expect( + buildInternalApiUrl( + 'https://example.com/api/rag-chat?conversation=abc', + '/api/fetch-documents', + ), + ).toBe('https://example.com/api/fetch-documents'); + }); +}); diff --git a/ui/pages/api/fetch-documents.ts b/ui/pages/api/fetch-documents.ts index 9304e48..8dc754e 100644 --- a/ui/pages/api/fetch-documents.ts +++ b/ui/pages/api/fetch-documents.ts @@ -1,23 +1,41 @@ -import type { NextApiRequest, NextApiResponse } from "next"; -import { ChromaClient, TransformersEmbeddingFunction } from "chromadb"; +import type { NextApiRequest, NextApiResponse } from 'next'; -export default async function handler(req: NextApiRequest, res: NextApiResponse) { +import { getChromaPath } from '@/utils/server/rag-documents'; + +import { ChromaClient, TransformersEmbeddingFunction } from 'chromadb'; + +export default async function handler( + req: NextApiRequest, + res: NextApiResponse, +) { try { + if (req.method !== 'POST') { + res.setHeader('Allow', 'POST'); + return res.status(405).json({ error: 'Method not allowed' }); + } + const client = new ChromaClient({ - path: "http://chroma-server:8000", + path: getChromaPath(), }); const query = req.body.input; + if (typeof query !== 'string' || query.length === 0) { + return res.status(400).json({ error: 'Missing query input' }); + } + const embedder = new TransformersEmbeddingFunction(); - const collection = await client.getOrCreateCollection({ name: "default-collection", embeddingFunction: embedder }); + const collection = await client.getOrCreateCollection({ + name: 'default-collection', + embeddingFunction: embedder, + }); - // query the collection - const results = await collection.query({ - nResults: 4, - queryTexts: [query] - }) + // query the collection + const results = await collection.query({ + nResults: 4, + queryTexts: [query], + }); res.status(200).json(results); } catch (error) { @@ -29,4 +47,4 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) } res.status(500).json({ error: 'An unexpected error occurred :(' }); } -} \ No newline at end of file +} diff --git a/ui/pages/api/rag-chat.ts b/ui/pages/api/rag-chat.ts index ce84d67..eb2079c 100644 --- a/ui/pages/api/rag-chat.ts +++ b/ui/pages/api/rag-chat.ts @@ -1,6 +1,6 @@ import { DEFAULT_SYSTEM_PROMPT, DEFAULT_TEMPERATURE } from '@/utils/app/const'; import { OpenAIError, OpenAIStream } from '@/utils/server'; -import { codeBlock, oneLine } from 'common-tags' +import { buildInternalApiUrl } from '@/utils/server/rag-documents'; import { ChatBody, Message } from '@/types/chat'; @@ -9,46 +9,51 @@ import wasm from '../../node_modules/@dqbd/tiktoken/lite/tiktoken_bg.wasm?module import tiktokenModel from '@dqbd/tiktoken/encoders/cl100k_base.json'; import { Tiktoken, init } from '@dqbd/tiktoken/lite/init'; +import { codeBlock, oneLine } from 'common-tags'; export const config = { runtime: 'edge', }; // Function to fetch and format documents -async function fetchAndFormatDocuments(lastMessageContent: string) { +async function fetchAndFormatDocuments( + lastMessageContent: string, + requestUrl: string, +) { try { - console.log("fetching documents") - const response = await fetch('http://localhost:3000/api/fetch-documents', { - method: 'POST', - headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify({ input: lastMessageContent }), - }); - + console.log('fetching documents'); + const response = await fetch( + buildInternalApiUrl(requestUrl, '/api/fetch-documents'), + { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ input: lastMessageContent }), + }, + ); + if (!response.ok) { throw new Error(`Error fetching documents: ${response.statusText}`); } const data = await response.json(); - const result = data.metadatas[0].map((metadata: any, index: number) => { - return `Source ${index + 1}) Title: ${metadata.title}, Page: ${metadata.page}, Content: ${data.documents[0][index]}\n`; - }).join(''); + const result = data.metadatas[0] + .map((metadata: any, index: number) => { + return `Source ${index + 1}) Title: ${metadata.title}, Page: ${ + metadata.page + }, Content: ${data.documents[0][index]}\n`; + }) + .join(''); console.log(result); return result; - } catch (error) { console.error('Error fetching and formatting documents:', error); throw error; // You may want to throw a more specific error object here } } - - - - const handler = async (req: Request): Promise => { - try { const { model, messages, key, prompt, temperature } = (await req.json()) as ChatBody; @@ -85,8 +90,11 @@ const handler = async (req: Request): Promise => { const lastMessage = messages[messages.length - 1]; - const relevantDocuments = await fetchAndFormatDocuments(lastMessage.content); - + const relevantDocuments = await fetchAndFormatDocuments( + lastMessage.content, + req.url, + ); + let temperatureToUse = temperature; if (temperatureToUse == null) { temperatureToUse = DEFAULT_TEMPERATURE; @@ -97,22 +105,20 @@ const handler = async (req: Request): Promise => { let tokenCount = prompt_tokens.length; let messagesToSend: Message[] = []; - encoding.free(); console.log(model, promptToSend, temperatureToUse, key, messagesToSend); - - messagesToSend = [ + messagesToSend = [ { - role: "user", + role: 'user', content: codeBlock` Here is the relevant documentation: ${relevantDocuments} `, }, { - role: "user", + role: 'user', content: codeBlock` ${oneLine` Answer my next question using only the above documentation. @@ -135,19 +141,18 @@ const handler = async (req: Request): Promise => { `, }, { - role: "user", + role: 'user', content: codeBlock` Here is my question: ${oneLine`${lastMessage.content}`} `, }, - ] - + ]; const stream = await OpenAIStream( model, promptToSend, - 0, + temperatureToUse, key, messagesToSend, ); diff --git a/ui/utils/server/rag-documents.ts b/ui/utils/server/rag-documents.ts new file mode 100644 index 0000000..8aaf383 --- /dev/null +++ b/ui/utils/server/rag-documents.ts @@ -0,0 +1,7 @@ +export const DEFAULT_CHROMA_PATH = 'http://chroma-server:8000'; + +export const getChromaPath = () => + process.env.CHROMA_PATH || DEFAULT_CHROMA_PATH; + +export const buildInternalApiUrl = (requestUrl: string, pathname: string) => + new URL(pathname, requestUrl).toString();