Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 105 additions & 0 deletions ui/__tests__/scientific-rag.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import {
buildCitationKey,
buildScientificMetadata,
detectScientificSection,
formatRetrievedDocument,
formatRetrievedDocuments,
} from '@/utils/server/scientific-rag';
import { describe, expect, it } from 'vitest';

// Unit tests for the scientific-RAG helpers: section detection, citation-key
// construction, metadata assembly, and formatting of retrieved documents.
describe('scientific-rag helpers', () => {
  it('detects scientific sections from document chunks', () => {
    // Each pair is [chunk text, section label the detector should report].
    const cases: Array<[string, string]> = [
      ['Abstract\nThis paper studies retrieval.', 'abstract'],
      ['METHODS\nWe used a benchmark.', 'methods'],
      ['Materials and Methods\nWe collected samples.', 'materials and methods'],
    ];

    for (const [chunk, expectedSection] of cases) {
      expect(detectScientificSection(chunk)).toBe(expectedSection);
    }
  });

  it('builds stable citation keys', () => {
    // pageChunkIndex 2 is expected to surface as "c3" in the key.
    const citationKey = buildCitationKey({
      title: 'Scientific RAG for Papers!',
      page: 4,
      pageChunkIndex: 2,
    });

    expect(citationKey).toBe('scientific-rag-for-papers:p4:c3');
  });

  it('builds metadata with title fallback and section', () => {
    // The embedded PDF title is empty, so the file name passed as the second
    // argument is what the test expects to see as the title.
    const built = buildScientificMetadata(
      {
        pageContent: 'Results\nThe model improved citation accuracy.',
        metadata: {
          loc: { pageNumber: 7 },
          pdf: { info: { Title: '' } },
          source: '/tmp/paper.pdf',
        },
      },
      'paper.pdf',
      0,
    );

    expect(built).toMatchObject({
      title: 'paper.pdf',
      page: 7,
      section: 'results',
      citationKey: 'paper-pdf:p7:c1',
    });
  });

  it('can build page-local citation keys when global chunk order differs', () => {
    // Global chunk index (14) and page-local index (1) are passed separately;
    // the expected key suffix "c2" follows the page-local one.
    const built = buildScientificMetadata(
      {
        pageContent: 'Discussion\nThe citation key should be local to a page.',
        metadata: {
          loc: { pageNumber: 9 },
          pdf: { info: { Title: 'Long Paper' } },
          source: '/tmp/long-paper.pdf',
        },
      },
      'long-paper.pdf',
      14,
      1,
    );

    expect(built).toMatchObject({
      chunkIndex: 14,
      pageChunkIndex: 1,
      citationKey: 'long-paper:p9:c2',
    });
  });

  it('formats retrieved documents with citation metadata', () => {
    const rendered = formatRetrievedDocument({
      content: 'Citation-aware answer context.',
      metadata: {
        title: 'Paper',
        page: 2,
        section: 'discussion',
        citationKey: 'paper:p2:c1',
      },
      distance: 0.12345,
      index: 0,
    });

    // index 0 is expected to be rendered as "Source 1".
    expect(rendered).toContain('Source 1 [paper:p2:c1]');
  });

  it('formats Chroma retrieval results defensively', () => {
    // A populated result: one query with one document, metadata and distance.
    const populated = formatRetrievedDocuments({
      documents: [['Document context.']],
      metadatas: [[{ citationKey: 'paper:p1:c1', title: 'Paper', page: 1 }]],
      distances: [[0.42]],
    });
    expect(populated).toContain('Source 1 [paper:p1:c1]');

    // Empty and missing document lists should both yield an empty string.
    const fromEmptyList = formatRetrievedDocuments({ documents: [[]] });
    expect(fromEmptyList).toBe('');

    const fromMissingList = formatRetrievedDocuments({ documents: undefined });
    expect(fromMissingList).toBe('');
  });
});
38 changes: 27 additions & 11 deletions ui/pages/api/fetch-documents.ts
Original file line number Diff line number Diff line change
@@ -1,23 +1,39 @@
import type { NextApiRequest, NextApiResponse } from "next";
import { ChromaClient, TransformersEmbeddingFunction } from "chromadb";
import type { NextApiRequest, NextApiResponse } from 'next';
import { ChromaClient, TransformersEmbeddingFunction } from 'chromadb';

export default async function handler(req: NextApiRequest, res: NextApiResponse) {
try {
if (req.method !== 'POST') {
return res.status(405).end();
}

const client = new ChromaClient({
path: "http://chroma-server:8000",
path: process.env.CHROMA_PATH || 'http://chroma-server:8000',
});

const query = req.body.input;
const query = typeof req.body.input === 'string' ? req.body.input.trim() : '';
const requestedResults = Number(req.body.nResults ?? 6);
const nResults = Number.isFinite(requestedResults)
? Math.min(Math.max(Math.trunc(requestedResults), 1), 10)
: 6;

if (!query) {
return res.status(400).json({ error: 'Missing retrieval query' });
}

const embedder = new TransformersEmbeddingFunction();

const collection = await client.getOrCreateCollection({ name: "default-collection", embeddingFunction: embedder });
const collection = await client.getOrCreateCollection({
name: 'default-collection',
embeddingFunction: embedder,
});

// query the collection
const results = await collection.query({
nResults: 4,
queryTexts: [query]
})
// query the collection
const results = await collection.query({
nResults,
queryTexts: [query],
include: ['documents', 'metadatas', 'distances'] as any,
});

res.status(200).json(results);
} catch (error) {
Expand All @@ -29,4 +45,4 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse)
}
res.status(500).json({ error: 'An unexpected error occurred :(' });
}
}
}
60 changes: 37 additions & 23 deletions ui/pages/api/inject-documents.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,17 @@ import type { NextApiRequest, NextApiResponse } from 'next';
import { ChromaClient, TransformersEmbeddingFunction } from 'chromadb';
import { IncomingForm } from 'formidable';
import { PDFLoader } from 'langchain/document_loaders/fs/pdf';
import { RecursiveCharacterTextSplitter } from "langchain/text_splitter";
import { RecursiveCharacterTextSplitter } from 'langchain/text_splitter';

import path from 'path';
import { v4 as uuidv4 } from 'uuid';

import {
SCIENTIFIC_TEXT_SEPARATORS,
buildScientificMetadata,
type ScientificDocument,
} from '@/utils/server/scientific-rag';

export const config = {
api: {
bodyParser: false,
Expand All @@ -33,20 +39,25 @@ export default async function handler(
path: process.env.CHROMA_PATH || 'http://chroma-server:8000',
});

const loader = new PDFLoader(files.pdf[0].filepath);
const pdf = files.pdf;
const pdfFile = Array.isArray(pdf) ? pdf[0] : pdf;

const originalDocs = await loader.load();
if (!pdfFile || typeof pdfFile.filepath !== 'string') {
return res.status(400).json({ error: 'Missing PDF upload' });
}

console.log(JSON.stringify(originalDocs));
const loader = new PDFLoader(pdfFile.filepath);

const originalDocs = await loader.load();

const splitter = new RecursiveCharacterTextSplitter({
chunkSize: 500,
chunkOverlap: 100,
});
chunkSize: 900,
chunkOverlap: 180,
separators: SCIENTIFIC_TEXT_SEPARATORS,
});

const docs = await splitter.splitDocuments(originalDocs);

// Process the documents and perform other logic
const { ids, metadatas, documentContents } = processDocuments(docs);

Expand Down Expand Up @@ -75,27 +86,30 @@ export default async function handler(
}
}

function processDocuments(docs: any) {
const ids = [];
function processDocuments(docs: ScientificDocument[]) {
const ids: string[] = [];
const metadatas = [];
const documentContents = [];
const documentContents: string[] = [];
const pageChunkCounts = new Map<string, number>();

for (const document of docs) {
for (let index = 0; index < docs.length; index += 1) {
const document = docs[index];
// Generate an ID for each document, or use some existing unique identifier
const id = uuidv4();
ids.push(id);

const fallbackTitle = path.basename(document.metadata.source);
const titleFromMetadata = document.metadata.pdf.info.Title;

const title = titleFromMetadata && titleFromMetadata.length > 0 ? titleFromMetadata : fallbackTitle;


const metadata = {
title: title,
page: document.metadata.loc.pageNumber, // Define this function to extract chapter info
source: document.metadata.source, // Define this function to extract verse info
};
const fallbackTitle = path.basename(document.metadata.source ?? 'document.pdf');
const page = document.metadata.loc?.pageNumber ?? 'unknown';
const pageKey = `${document.metadata.source ?? fallbackTitle}:${page}`;
const pageChunkIndex = pageChunkCounts.get(pageKey) ?? 0;
pageChunkCounts.set(pageKey, pageChunkIndex + 1);

const metadata = buildScientificMetadata(
document,
fallbackTitle,
index,
pageChunkIndex,
);
metadatas.push(metadata);

// Add the page content to the documents array
Expand Down
44 changes: 26 additions & 18 deletions ui/pages/api/rag-chat.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import { DEFAULT_SYSTEM_PROMPT, DEFAULT_TEMPERATURE } from '@/utils/app/const';
import { OpenAIError, OpenAIStream } from '@/utils/server';
import { codeBlock, oneLine } from 'common-tags'
import { codeBlock, oneLine } from 'common-tags';

import { ChatBody, Message } from '@/types/chat';
import { formatRetrievedDocuments } from '@/utils/server/scientific-rag';

// @ts-expect-error
import wasm from '../../node_modules/@dqbd/tiktoken/lite/tiktoken_bg.wasm?module';
Expand All @@ -15,31 +16,29 @@ export const config = {
};

// Function to fetch and format documents
async function fetchAndFormatDocuments(lastMessageContent: string) {
async function fetchAndFormatDocuments(
baseUrl: string,
lastMessageContent: string,
) {
try {
console.log("fetching documents")
const response = await fetch('http://localhost:3000/api/fetch-documents', {
const response = await fetch(`${baseUrl}/api/fetch-documents`, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ input: lastMessageContent }),
body: JSON.stringify({ input: lastMessageContent, nResults: 6 }),
});

if (!response.ok) {
throw new Error(`Error fetching documents: ${response.statusText}`);
}

const data = await response.json();
const result = data.metadatas[0].map((metadata: any, index: number) => {
return `Source ${index + 1}) Title: ${metadata.title}, Page: ${metadata.page}, Content: ${data.documents[0][index]}\n`;
}).join('');
const result = formatRetrievedDocuments(data);

console.log(result);

return result;
return result || 'No relevant documents were retrieved.';

} catch (error) {
console.error('Error fetching and formatting documents:', error);
throw error; // You may want to throw a more specific error object here
throw error;
}
}

Expand All @@ -64,7 +63,7 @@ const handler = async (req: Request): Promise<Response> => {
${oneLine`
You are a very enthusiastic AI assistant who loves
to help people! Given the following information from
relevant documentation, answer the user's question using
relevant scientific documentation, answer the user's question using
only that information, outputted in markdown format.
`}

Expand All @@ -75,7 +74,7 @@ const handler = async (req: Request): Promise<Response> => {
`}

${oneLine`
Always include citations from the documentation.
Every factual claim must include citation keys from the documentation.
`}
`;

Expand All @@ -85,7 +84,11 @@ const handler = async (req: Request): Promise<Response> => {

const lastMessage = messages[messages.length - 1];

const relevantDocuments = await fetchAndFormatDocuments(lastMessage.content);
const baseUrl = new URL(req.url).origin;
const relevantDocuments = await fetchAndFormatDocuments(
baseUrl,
lastMessage.content,
);

let temperatureToUse = temperature;
if (temperatureToUse == null) {
Expand All @@ -100,9 +103,6 @@ const handler = async (req: Request): Promise<Response> => {

encoding.free();

console.log(model, promptToSend, temperatureToUse, key, messagesToSend);


messagesToSend = [
{
role: "user",
Expand All @@ -121,6 +121,14 @@ const handler = async (req: Request): Promise<Response> => {
${oneLine`
- Do not make up answers that are not provided in the documentation.
`}
${oneLine`
- Cite sources using the exact citation keys shown in square brackets,
for example [paper-title:p3:c2].
`}
${oneLine`
- Prefer sources with lower retrieval distance when multiple sources
contain similar information.
`}
${oneLine`
- If you are unsure and the answer is not explicitly written
in the documentation context, say
Expand Down
Loading