Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions ui/__tests__/rag-documents.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import {
DEFAULT_CHROMA_PATH,
buildInternalApiUrl,
getChromaPath,
} from '@/utils/server/rag-documents';

import { afterEach, describe, expect, it } from 'vitest';

describe('rag document helpers', () => {
afterEach(() => {
delete process.env.CHROMA_PATH;
});

it('uses the Docker Chroma default when no override is configured', () => {
expect(getChromaPath()).toBe(DEFAULT_CHROMA_PATH);
});

it('uses CHROMA_PATH when configured', () => {
process.env.CHROMA_PATH = 'http://custom-chroma:8000';

expect(getChromaPath()).toBe('http://custom-chroma:8000');
});

it('builds internal API URLs from the current request origin', () => {
expect(
buildInternalApiUrl(
'https://example.com/api/rag-chat?conversation=abc',
'/api/fetch-documents',
),
).toBe('https://example.com/api/fetch-documents');
});
});
40 changes: 29 additions & 11 deletions ui/pages/api/fetch-documents.ts
Original file line number Diff line number Diff line change
@@ -1,23 +1,41 @@
import type { NextApiRequest, NextApiResponse } from "next";
import { ChromaClient, TransformersEmbeddingFunction } from "chromadb";
import type { NextApiRequest, NextApiResponse } from 'next';

export default async function handler(req: NextApiRequest, res: NextApiResponse) {
import { getChromaPath } from '@/utils/server/rag-documents';

import { ChromaClient, TransformersEmbeddingFunction } from 'chromadb';

export default async function handler(
req: NextApiRequest,
res: NextApiResponse,
) {
try {
if (req.method !== 'POST') {
res.setHeader('Allow', 'POST');
return res.status(405).json({ error: 'Method not allowed' });
}

const client = new ChromaClient({
path: "http://chroma-server:8000",
path: getChromaPath(),
});

const query = req.body.input;

if (typeof query !== 'string' || query.length === 0) {
return res.status(400).json({ error: 'Missing query input' });
}

const embedder = new TransformersEmbeddingFunction();

const collection = await client.getOrCreateCollection({ name: "default-collection", embeddingFunction: embedder });
const collection = await client.getOrCreateCollection({
name: 'default-collection',
embeddingFunction: embedder,
});

// query the collection
const results = await collection.query({
nResults: 4,
queryTexts: [query]
})
// query the collection
const results = await collection.query({
nResults: 4,
queryTexts: [query],
});

res.status(200).json(results);
} catch (error) {
Expand All @@ -29,4 +47,4 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
}
res.status(500).json({ error: 'An unexpected error occurred :(' });
}
}
}
63 changes: 34 additions & 29 deletions ui/pages/api/rag-chat.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { DEFAULT_SYSTEM_PROMPT, DEFAULT_TEMPERATURE } from '@/utils/app/const';
import { OpenAIError, OpenAIStream } from '@/utils/server';
import { codeBlock, oneLine } from 'common-tags'
import { buildInternalApiUrl } from '@/utils/server/rag-documents';

import { ChatBody, Message } from '@/types/chat';

Expand All @@ -9,46 +9,51 @@ import wasm from '../../node_modules/@dqbd/tiktoken/lite/tiktoken_bg.wasm?module

import tiktokenModel from '@dqbd/tiktoken/encoders/cl100k_base.json';
import { Tiktoken, init } from '@dqbd/tiktoken/lite/init';
import { codeBlock, oneLine } from 'common-tags';

export const config = {
runtime: 'edge',
};

// Function to fetch and format documents
async function fetchAndFormatDocuments(lastMessageContent: string) {
async function fetchAndFormatDocuments(
lastMessageContent: string,
requestUrl: string,
) {
try {
console.log("fetching documents")
const response = await fetch('http://localhost:3000/api/fetch-documents', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ input: lastMessageContent }),
});

console.log('fetching documents');
const response = await fetch(
buildInternalApiUrl(requestUrl, '/api/fetch-documents'),
{
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ input: lastMessageContent }),
},
);

if (!response.ok) {
throw new Error(`Error fetching documents: ${response.statusText}`);
}

const data = await response.json();
const result = data.metadatas[0].map((metadata: any, index: number) => {
return `Source ${index + 1}) Title: ${metadata.title}, Page: ${metadata.page}, Content: ${data.documents[0][index]}\n`;
}).join('');
const result = data.metadatas[0]
.map((metadata: any, index: number) => {
return `Source ${index + 1}) Title: ${metadata.title}, Page: ${
metadata.page
}, Content: ${data.documents[0][index]}\n`;
})
.join('');

console.log(result);

return result;

} catch (error) {
console.error('Error fetching and formatting documents:', error);
throw error; // You may want to throw a more specific error object here
}
}





const handler = async (req: Request): Promise<Response> => {

try {
const { model, messages, key, prompt, temperature } =
(await req.json()) as ChatBody;
Expand Down Expand Up @@ -85,8 +90,11 @@ const handler = async (req: Request): Promise<Response> => {

const lastMessage = messages[messages.length - 1];

const relevantDocuments = await fetchAndFormatDocuments(lastMessage.content);

const relevantDocuments = await fetchAndFormatDocuments(
lastMessage.content,
req.url,
);

let temperatureToUse = temperature;
if (temperatureToUse == null) {
temperatureToUse = DEFAULT_TEMPERATURE;
Expand All @@ -97,22 +105,20 @@ const handler = async (req: Request): Promise<Response> => {
let tokenCount = prompt_tokens.length;
let messagesToSend: Message[] = [];


encoding.free();

console.log(model, promptToSend, temperatureToUse, key, messagesToSend);


messagesToSend = [
messagesToSend = [
{
role: "user",
role: 'user',
content: codeBlock`
Here is the relevant documentation:
${relevantDocuments}
`,
},
{
role: "user",
role: 'user',
content: codeBlock`
${oneLine`
Answer my next question using only the above documentation.
Expand All @@ -135,19 +141,18 @@ const handler = async (req: Request): Promise<Response> => {
`,
},
{
role: "user",
role: 'user',
content: codeBlock`
Here is my question:
${oneLine`${lastMessage.content}`}
`,
},
]

];

const stream = await OpenAIStream(
model,
promptToSend,
0,
temperatureToUse,
key,
messagesToSend,
);
Expand Down
7 changes: 7 additions & 0 deletions ui/utils/server/rag-documents.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
export const DEFAULT_CHROMA_PATH = 'http://chroma-server:8000';

export const getChromaPath = () =>
process.env.CHROMA_PATH || DEFAULT_CHROMA_PATH;

export const buildInternalApiUrl = (requestUrl: string, pathname: string) =>
new URL(pathname, requestUrl).toString();