diff --git a/ui/__tests__/scientific-rag.test.ts b/ui/__tests__/scientific-rag.test.ts
new file mode 100644
index 0000000..8e17da3
--- /dev/null
+++ b/ui/__tests__/scientific-rag.test.ts
@@ -0,0 +1,114 @@
+import {
+  buildCitationKey,
+  buildScientificMetadata,
+  detectScientificSection,
+  formatRetrievedDocument,
+  formatRetrievedDocuments,
+} from '@/utils/server/scientific-rag';
+import { describe, expect, it } from 'vitest';
+
+// Exercises the helpers behind scientific PDF ingestion and retrieval
+// formatting; every case calls a helper directly with in-memory fixtures.
+describe('scientific-rag helpers', () => {
+  it('detects scientific sections from document chunks', () => {
+    const abstractChunk = 'Abstract\nThis paper studies retrieval.';
+    const methodsChunk = 'METHODS\nWe used a benchmark.';
+    const materialsChunk = 'Materials and Methods\nWe collected samples.';
+
+    expect(detectScientificSection(abstractChunk)).toBe('abstract');
+    expect(detectScientificSection(methodsChunk)).toBe('methods');
+    expect(detectScientificSection(materialsChunk)).toBe(
+      'materials and methods',
+    );
+  });
+
+  it('builds stable citation keys', () => {
+    const citationKey = buildCitationKey({
+      title: 'Scientific RAG for Papers!',
+      page: 4,
+      pageChunkIndex: 2,
+    });
+
+    // The slug drops punctuation; the chunk number is one-based.
+    expect(citationKey).toBe('scientific-rag-for-papers:p4:c3');
+  });
+
+  it('builds metadata with title fallback and section', () => {
+    // An empty embedded title must fall back to the supplied file name.
+    const untitledChunk = {
+      pageContent: 'Results\nThe model improved citation accuracy.',
+      metadata: {
+        loc: { pageNumber: 7 },
+        pdf: { info: { Title: '' } },
+        source: '/tmp/paper.pdf',
+      },
+    };
+
+    const metadata = buildScientificMetadata(untitledChunk, 'paper.pdf', 0);
+
+    expect(metadata).toMatchObject({
+      title: 'paper.pdf',
+      page: 7,
+      section: 'results',
+      citationKey: 'paper-pdf:p7:c1',
+    });
+  });
+
+  it('can build page-local citation keys when global chunk order differs', () => {
+    const lateChunk = {
+      pageContent: 'Discussion\nThe citation key should be local to a page.',
+      metadata: {
+        loc: { pageNumber: 9 },
+        pdf: { info: { Title: 'Long Paper' } },
+        source: '/tmp/long-paper.pdf',
+      },
+    };
+    const globalIndex = 14;
+    const indexOnPage = 1;
+
+    const metadata = buildScientificMetadata(
+      lateChunk,
+      'long-paper.pdf',
+      globalIndex,
+      indexOnPage,
+    );
+
+    expect(metadata).toMatchObject({
+      chunkIndex: 14,
+      pageChunkIndex: 1,
+      citationKey: 'long-paper:p9:c2',
+    });
+  });
+
+  it('formats retrieved documents with citation metadata', () => {
+    const formatted = formatRetrievedDocument({
+      content: 'Citation-aware answer context.',
+      metadata: {
+        title: 'Paper',
+        page: 2,
+        section: 'discussion',
+        citationKey: 'paper:p2:c1',
+      },
+      distance: 0.12345,
+      index: 0,
+    });
+
+    expect(formatted).toContain('Source 1 [paper:p2:c1]');
+  });
+
+  it('formats Chroma retrieval results defensively', () => {
+    const wellFormed = {
+      documents: [['Document context.']],
+      metadatas: [[{ citationKey: 'paper:p1:c1', title: 'Paper', page: 1 }]],
+      distances: [[0.42]],
+    };
+
+    expect(formatRetrievedDocuments(wellFormed)).toContain(
+      'Source 1 [paper:p1:c1]',
+    );
+
+    // Empty or missing result sets yield an empty string.
+    expect(formatRetrievedDocuments({ documents: [[]] })).toBe('');
+    expect(formatRetrievedDocuments({ documents: undefined })).toBe('');
+  });
+});
diff --git a/ui/pages/api/fetch-documents.ts b/ui/pages/api/fetch-documents.ts index 9304e48..08decf3 100644 --- a/ui/pages/api/fetch-documents.ts +++ b/ui/pages/api/fetch-documents.ts @@ -1,23 +1,39 @@ -import type { NextApiRequest, NextApiResponse } from "next"; -import { ChromaClient, TransformersEmbeddingFunction } from "chromadb"; +import type { NextApiRequest, NextApiResponse } from 'next'; +import { ChromaClient, TransformersEmbeddingFunction } from 'chromadb'; export default async function handler(req: NextApiRequest, res: NextApiResponse) { try { + if (req.method !== 'POST') { + return res.status(405).end(); + } + const client = new ChromaClient({ - path: "http://chroma-server:8000", + path: process.env.CHROMA_PATH || 'http://chroma-server:8000', }); - const query = req.body.input; + const query = typeof req.body.input === 'string' ? req.body.input.trim() : ''; + const requestedResults = Number(req.body.nResults ?? 6); + const nResults = Number.isFinite(requestedResults) + ? 
Math.min(Math.max(Math.trunc(requestedResults), 1), 10) + : 6; + + if (!query) { + return res.status(400).json({ error: 'Missing retrieval query' }); + } const embedder = new TransformersEmbeddingFunction(); - const collection = await client.getOrCreateCollection({ name: "default-collection", embeddingFunction: embedder }); + const collection = await client.getOrCreateCollection({ + name: 'default-collection', + embeddingFunction: embedder, + }); - // query the collection - const results = await collection.query({ - nResults: 4, - queryTexts: [query] - }) + // query the collection + const results = await collection.query({ + nResults, + queryTexts: [query], + include: ['documents', 'metadatas', 'distances'] as any, + }); res.status(200).json(results); } catch (error) { @@ -29,4 +45,4 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) } res.status(500).json({ error: 'An unexpected error occurred :(' }); } -} \ No newline at end of file +} diff --git a/ui/pages/api/inject-documents.ts b/ui/pages/api/inject-documents.ts index 532a635..1dbf39e 100644 --- a/ui/pages/api/inject-documents.ts +++ b/ui/pages/api/inject-documents.ts @@ -3,11 +3,17 @@ import type { NextApiRequest, NextApiResponse } from 'next'; import { ChromaClient, TransformersEmbeddingFunction } from 'chromadb'; import { IncomingForm } from 'formidable'; import { PDFLoader } from 'langchain/document_loaders/fs/pdf'; -import { RecursiveCharacterTextSplitter } from "langchain/text_splitter"; +import { RecursiveCharacterTextSplitter } from 'langchain/text_splitter'; import path from 'path'; import { v4 as uuidv4 } from 'uuid'; +import { + SCIENTIFIC_TEXT_SEPARATORS, + buildScientificMetadata, + type ScientificDocument, +} from '@/utils/server/scientific-rag'; + export const config = { api: { bodyParser: false, @@ -33,20 +39,25 @@ export default async function handler( path: process.env.CHROMA_PATH || 'http://chroma-server:8000', }); - const loader = new 
PDFLoader(files.pdf[0].filepath); + const pdf = files.pdf; + const pdfFile = Array.isArray(pdf) ? pdf[0] : pdf; - const originalDocs = await loader.load(); + if (!pdfFile || typeof pdfFile.filepath !== 'string') { + return res.status(400).json({ error: 'Missing PDF upload' }); + } - console.log(JSON.stringify(originalDocs)); + const loader = new PDFLoader(pdfFile.filepath); + const originalDocs = await loader.load(); const splitter = new RecursiveCharacterTextSplitter({ - chunkSize: 500, - chunkOverlap: 100, - }); + chunkSize: 900, + chunkOverlap: 180, + separators: SCIENTIFIC_TEXT_SEPARATORS, + }); const docs = await splitter.splitDocuments(originalDocs); - + // Process the documents and perform other logic const { ids, metadatas, documentContents } = processDocuments(docs); @@ -75,27 +86,30 @@ export default async function handler( } } -function processDocuments(docs: any) { - const ids = []; +function processDocuments(docs: ScientificDocument[]) { + const ids: string[] = []; const metadatas = []; - const documentContents = []; + const documentContents: string[] = []; + const pageChunkCounts = new Map(); - for (const document of docs) { + for (let index = 0; index < docs.length; index += 1) { + const document = docs[index]; // Generate an ID for each document, or use some existing unique identifier const id = uuidv4(); ids.push(id); - const fallbackTitle = path.basename(document.metadata.source); - const titleFromMetadata = document.metadata.pdf.info.Title; - - const title = titleFromMetadata && titleFromMetadata.length > 0 ? titleFromMetadata : fallbackTitle; - - - const metadata = { - title: title, - page: document.metadata.loc.pageNumber, // Define this function to extract chapter info - source: document.metadata.source, // Define this function to extract verse info - }; + const fallbackTitle = path.basename(document.metadata.source ?? 'document.pdf'); + const page = document.metadata.loc?.pageNumber ?? 
'unknown'; + const pageKey = `${document.metadata.source ?? fallbackTitle}:${page}`; + const pageChunkIndex = pageChunkCounts.get(pageKey) ?? 0; + pageChunkCounts.set(pageKey, pageChunkIndex + 1); + + const metadata = buildScientificMetadata( + document, + fallbackTitle, + index, + pageChunkIndex, + ); metadatas.push(metadata); // Add the page content to the documents array diff --git a/ui/pages/api/rag-chat.ts b/ui/pages/api/rag-chat.ts index ce84d67..5d42eb6 100644 --- a/ui/pages/api/rag-chat.ts +++ b/ui/pages/api/rag-chat.ts @@ -1,8 +1,9 @@ import { DEFAULT_SYSTEM_PROMPT, DEFAULT_TEMPERATURE } from '@/utils/app/const'; import { OpenAIError, OpenAIStream } from '@/utils/server'; -import { codeBlock, oneLine } from 'common-tags' +import { codeBlock, oneLine } from 'common-tags'; import { ChatBody, Message } from '@/types/chat'; +import { formatRetrievedDocuments } from '@/utils/server/scientific-rag'; // @ts-expect-error import wasm from '../../node_modules/@dqbd/tiktoken/lite/tiktoken_bg.wasm?module'; @@ -15,13 +16,15 @@ export const config = { }; // Function to fetch and format documents -async function fetchAndFormatDocuments(lastMessageContent: string) { +async function fetchAndFormatDocuments( + baseUrl: string, + lastMessageContent: string, +) { try { - console.log("fetching documents") - const response = await fetch('http://localhost:3000/api/fetch-documents', { + const response = await fetch(`${baseUrl}/api/fetch-documents`, { method: 'POST', headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify({ input: lastMessageContent }), + body: JSON.stringify({ input: lastMessageContent, nResults: 6 }), }); if (!response.ok) { @@ -29,17 +32,13 @@ async function fetchAndFormatDocuments(lastMessageContent: string) { } const data = await response.json(); - const result = data.metadatas[0].map((metadata: any, index: number) => { - return `Source ${index + 1}) Title: ${metadata.title}, Page: ${metadata.page}, Content: ${data.documents[0][index]}\n`; - 
}).join(''); + const result = formatRetrievedDocuments(data); - console.log(result); - - return result; + return result || 'No relevant documents were retrieved.'; } catch (error) { console.error('Error fetching and formatting documents:', error); - throw error; // You may want to throw a more specific error object here + throw error; } } @@ -64,7 +63,7 @@ const handler = async (req: Request): Promise => { ${oneLine` You are a very enthusiastic AI assistant who loves to help people! Given the following information from - relevant documentation, answer the user's question using + relevant scientific documentation, answer the user's question using only that information, outputted in markdown format. `} @@ -75,7 +74,7 @@ const handler = async (req: Request): Promise => { `} ${oneLine` - Always include citations from the documentation. + Every factual claim must include citation keys from the documentation. `} `; @@ -85,7 +84,11 @@ const handler = async (req: Request): Promise => { const lastMessage = messages[messages.length - 1]; - const relevantDocuments = await fetchAndFormatDocuments(lastMessage.content); + const baseUrl = new URL(req.url).origin; + const relevantDocuments = await fetchAndFormatDocuments( + baseUrl, + lastMessage.content, + ); let temperatureToUse = temperature; if (temperatureToUse == null) { @@ -100,9 +103,6 @@ const handler = async (req: Request): Promise => { encoding.free(); - console.log(model, promptToSend, temperatureToUse, key, messagesToSend); - - messagesToSend = [ { role: "user", @@ -121,6 +121,14 @@ const handler = async (req: Request): Promise => { ${oneLine` - Do not make up answers that are not provided in the documentation. `} + ${oneLine` + - Cite sources using the exact citation keys shown in square brackets, + for example [paper-title:p3:c2]. + `} + ${oneLine` + - Prefer sources with lower retrieval distance when multiple sources + contain similar information. 
+ `} ${oneLine` - If you are unsure and the answer is not explicitly written in the documentation context, say
diff --git a/ui/utils/server/scientific-rag.ts b/ui/utils/server/scientific-rag.ts
new file mode 100644
index 0000000..eb36655
--- /dev/null
+++ b/ui/utils/server/scientific-rag.ts
@@ -0,0 +1,249 @@
+/**
+ * Helpers for ingesting scientific PDFs and for turning retrieval results
+ * into citation-aware prompt context.
+ */
+
+/** A loaded document chunk, as consumed by the helpers in this module. */
+export type ScientificDocument = {
+  pageContent: string;
+  metadata: {
+    loc?: {
+      pageNumber?: number;
+    };
+    pdf?: {
+      info?: {
+        Title?: string;
+      };
+    };
+    source?: string;
+    [key: string]: unknown;
+  };
+};
+
+/** Metadata built for each chunk by `buildScientificMetadata`. */
+export type ScientificChunkMetadata = {
+  title: string;
+  page: number | string;
+  source: string;
+  section: string;
+  chunkIndex: number;
+  pageChunkIndex: number;
+  citationKey: string;
+};
+
+const SCIENTIFIC_SECTIONS = [
+  'abstract',
+  'introduction',
+  'background',
+  'materials and methods',
+  'methodology',
+  'methods',
+  'experimental setup',
+  'experiments',
+  'results',
+  'evaluation',
+  'discussion',
+  'limitations',
+  'conclusion',
+  'references',
+];
+
+// Compiled once. `exec` reports the leftmost match, so the section named
+// earliest in the text wins rather than the one listed first above.
+const SCIENTIFIC_SECTION_PATTERN = new RegExp(
+  `\\b(?:${SCIENTIFIC_SECTIONS.join('|')})\\b`,
+  'i',
+);
+
+/**
+ * Separators for the recursive text splitter: section headings first, then
+ * progressively finer generic boundaries.
+ */
+export const SCIENTIFIC_TEXT_SEPARATORS = [
+  '\nAbstract',
+  '\nABSTRACT',
+  '\nIntroduction',
+  '\nINTRODUCTION',
+  '\nMethods',
+  '\nMETHODS',
+  '\nMaterials and Methods',
+  '\nMATERIALS AND METHODS',
+  '\nExperimental Setup',
+  '\nEXPERIMENTAL SETUP',
+  '\nExperiments',
+  '\nEXPERIMENTS',
+  '\nResults',
+  '\nRESULTS',
+  '\nEvaluation',
+  '\nEVALUATION',
+  '\nDiscussion',
+  '\nDISCUSSION',
+  '\nConclusion',
+  '\nCONCLUSION',
+  '\nReferences',
+  '\nREFERENCES',
+  '\n\n',
+  '\n',
+  '. ',
+  ' ',
+  '',
+];
+
+/**
+ * Returns the trimmed metadata title, or `fallbackTitle` when the title is
+ * missing or blank.
+ */
+export const normalizeTitle = (
+  titleFromMetadata: string | undefined,
+  fallbackTitle: string,
+): string => {
+  const title = titleFromMetadata?.trim();
+
+  return title ? title : fallbackTitle;
+};
+
+/**
+ * Names the known section mentioned earliest in the first eight lines of
+ * `content`, lower-cased, or `'body'` when none is mentioned.
+ */
+export const detectScientificSection = (content: string): string => {
+  const firstLines = content.split('\n').slice(0, 8).join(' ');
+  const match = SCIENTIFIC_SECTION_PATTERN.exec(firstLines);
+
+  return match ? match[0].toLowerCase() : 'body';
+};
+
+/**
+ * Builds a citation key of the form `<slug>:p<page>:c<n>`, where `<slug>` is
+ * the title reduced to at most 40 characters of `[a-z0-9-]` and `<n>` is the
+ * one-based chunk position on the page.
+ */
+export const buildCitationKey = ({
+  title,
+  page,
+  pageChunkIndex,
+}: {
+  title: string;
+  page: number | string;
+  pageChunkIndex: number;
+}): string => {
+  const slug = title
+    .toLowerCase()
+    .replace(/[^a-z0-9]+/g, '-')
+    .replace(/^-/, '')
+    .slice(0, 40)
+    // Trim after truncating: the cut can land right behind a separator.
+    .replace(/-$/, '');
+
+  return `${slug || 'document'}:p${page}:c${pageChunkIndex + 1}`;
+};
+
+/**
+ * Derives the stored metadata for one chunk.
+ *
+ * @param document - Chunk whose loader metadata and text are inspected.
+ * @param fallbackTitle - Used when the PDF has no usable title or source.
+ * @param chunkIndex - Position of the chunk within the whole document.
+ * @param pageChunkIndex - Position of the chunk on its page; defaults to
+ *   `chunkIndex`.
+ */
+export const buildScientificMetadata = (
+  document: ScientificDocument,
+  fallbackTitle: string,
+  chunkIndex: number,
+  pageChunkIndex = chunkIndex,
+): ScientificChunkMetadata => {
+  const title = normalizeTitle(document.metadata.pdf?.info?.Title, fallbackTitle);
+  const page = document.metadata.loc?.pageNumber ?? 'unknown';
+  const section = detectScientificSection(document.pageContent);
+
+  return {
+    title,
+    page,
+    source: document.metadata.source ?? fallbackTitle,
+    section,
+    chunkIndex,
+    pageChunkIndex,
+    citationKey: buildCitationKey({ title, page, pageChunkIndex }),
+  };
+};
+
+/**
+ * Renders one retrieved chunk as a labelled text block. Missing metadata
+ * fields fall back to placeholders; the page-chunk and distance lines are
+ * omitted when the value is not a number.
+ */
+export const formatRetrievedDocument = ({
+  content,
+  metadata,
+  distance,
+  index,
+}: {
+  content: string;
+  metadata: Partial<ScientificChunkMetadata>;
+  distance?: number;
+  index: number;
+}): string => {
+  const citationKey = metadata.citationKey ?? `source-${index + 1}`;
+  const { pageChunkIndex } = metadata;
+
+  const lines = [
+    `Source ${index + 1} [${citationKey}]`,
+    `Title: ${metadata.title ?? 'Untitled'}`,
+    `Page: ${metadata.page ?? 'unknown'}`,
+    `Section: ${metadata.section ?? 'body'}`,
+  ];
+
+  if (typeof pageChunkIndex === 'number') {
+    lines.push(`Page chunk: ${pageChunkIndex + 1}`);
+  }
+  if (typeof distance === 'number') {
+    lines.push(`Distance: ${distance.toFixed(4)}`);
+  }
+  lines.push(`Content: ${content}`);
+
+  return lines.join('\n');
+};
+
+// First inner array of a Chroma-style nested result, or [] when absent.
+const firstResultSet = (value: unknown): unknown[] =>
+  Array.isArray(value) && Array.isArray(value[0]) ? value[0] : [];
+
+/**
+ * Formats the first result set of a retrieval response into prompt context.
+ * Blank or non-string documents are skipped, and metadata or distance
+ * entries of the wrong type (including `null`) are ignored. Returns `''`
+ * when nothing usable was retrieved.
+ */
+export const formatRetrievedDocuments = (data: {
+  documents?: unknown;
+  metadatas?: unknown;
+  distances?: unknown;
+}): string => {
+  const documents = firstResultSet(data.documents);
+  const metadatas = firstResultSet(data.metadatas);
+  const distances = firstResultSet(data.distances);
+
+  return documents
+    .map((content, index) => {
+      if (typeof content !== 'string' || content.trim().length === 0) {
+        return '';
+      }
+
+      const rawMetadata = metadatas[index];
+      const rawDistance = distances[index];
+
+      return formatRetrievedDocument({
+        content,
+        // `typeof null` is 'object', so null entries need their own check.
+        metadata:
+          typeof rawMetadata === 'object' && rawMetadata !== null
+            ? (rawMetadata as Partial<ScientificChunkMetadata>)
+            : {},
+        distance: typeof rawDistance === 'number' ? rawDistance : undefined,
+        index,
+      });
+    })
+    .filter(Boolean)
+    .join('\n\n---\n\n');
+};