Commit 127a8ca

feat: ensure retriever returns an image and send it to the LLM base64 encoded
1 parent f4b0961 commit 127a8ca

13 files changed, +180 −346 lines

.changeset/large-plums-drum.md

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
+---
+"llamaindex": patch
+---
+
+Added support for multi-modal RAG (retriever and query engine) incl. an example
+Fixed persisting and loading image vector stores

examples/multimodal/data/1.jpg

-1.96 MB
Binary file not shown.

examples/multimodal/data/2.jpg

-4.77 MB
Binary file not shown.

examples/multimodal/data/3.jpg

-6.66 MB
Binary file not shown.

examples/multimodal/data/San Francisco.txt

Lines changed: 0 additions & 323 deletions
This file was deleted.

examples/multimodal/rag.ts

Lines changed: 19 additions & 4 deletions
@@ -1,5 +1,9 @@
 import {
+  CallbackManager,
+  ImageDocument,
+  ImageType,
   MultiModalResponseSynthesizer,
+  NodeWithScore,
   OpenAI,
   ServiceContext,
   VectorStoreIndex,
@@ -21,23 +25,34 @@ export async function createIndex(serviceContext: ServiceContext) {
 }
 
 async function main() {
+  let images: ImageType[] = [];
+  const callbackManager = new CallbackManager({
+    onRetrieve: ({ query, nodes }) => {
+      images = nodes
+        .filter(({ node }: NodeWithScore) => node instanceof ImageDocument)
+        .map(({ node }: NodeWithScore) => (node as ImageDocument).image);
+    },
+  });
   const llm = new OpenAI({ model: "gpt-4-vision-preview", maxTokens: 512 });
   const serviceContext = serviceContextFromDefaults({
     llm,
     chunkSize: 512,
     chunkOverlap: 20,
+    callbackManager,
   });
   const index = await createIndex(serviceContext);
 
   const queryEngine = index.asQueryEngine({
     responseSynthesizer: new MultiModalResponseSynthesizer({ serviceContext }),
-    // TODO: set imageSimilarityTopK: 1,
-    retriever: index.asRetriever({ similarityTopK: 2 }),
+    retriever: index.asRetriever({ similarityTopK: 3, imageSimilarityTopK: 1 }),
  });
   const result = await queryEngine.query(
-    "what are Vincent van Gogh's famous paintings",
+    "Tell me more about Vincent van Gogh's famous paintings",
+  );
+  console.log(result.response, "\n");
+  images.forEach((image) =>
+    console.log(`Image retrieved and used in inference: ${image.toString()}`),
   );
-  console.log(result.response);
 }
 
 main().catch(console.error);
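
Design note: the images used for inference are not part of the query response, so the example captures them through the CallbackManager's onRetrieve hook. A minimal sketch of that pattern in isolation, using the same llamaindex exports as the example (the service-context wiring is shown above):

import { CallbackManager, ImageDocument, ImageType, NodeWithScore } from "llamaindex";

// Collect the image documents the retriever returned; onRetrieve fires on each retrieval.
let retrievedImages: ImageType[] = [];
const callbackManager = new CallbackManager({
  onRetrieve: ({ nodes }) => {
    retrievedImages = nodes
      .filter(({ node }: NodeWithScore) => node instanceof ImageDocument)
      .map(({ node }: NodeWithScore) => (node as ImageDocument).image);
  },
});
// Pass `callbackManager` to serviceContextFromDefaults({ ..., callbackManager }) so
// retrieval events from the query engine reach this hook.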

examples/multimodal/retrieve.ts

Lines changed: 1 addition & 2 deletions
@@ -5,7 +5,6 @@ import {
   TextNode,
   VectorStoreIndex,
 } from "llamaindex";
-import * as path from "path";
 
 export async function createIndex() {
   // set up vector store index with two vector stores, one for text, the other for images
@@ -37,7 +36,7 @@ async function main() {
       continue;
     }
     if (node instanceof ImageNode) {
-      console.log(`Image: ${path.join(__dirname, node.id_)}`);
+      console.log(`Image: ${node.getUrl()}`);
     } else if (node instanceof TextNode) {
       console.log("Text:", (node as TextNode).text.substring(0, 128));
     }

packages/core/package.json

Lines changed: 1 addition & 0 deletions
@@ -10,6 +10,7 @@
     "@xenova/transformers": "^2.10.0",
     "assemblyai": "^4.0.0",
     "crypto-js": "^4.2.0",
+    "file-type": "^18.7.0",
     "js-tiktoken": "^1.0.8",
     "lodash": "^4.17.21",
     "mammoth": "^1.6.0",

packages/core/src/Node.ts

Lines changed: 7 additions & 0 deletions
@@ -1,4 +1,5 @@
 import CryptoJS from "crypto-js";
+import path from "path";
 import { v4 as uuidv4 } from "uuid";
 
 export enum NodeRelationship {
@@ -304,6 +305,12 @@ export class ImageNode<T extends Metadata = Metadata> extends TextNode<T> {
   getType(): ObjectType {
     return ObjectType.IMAGE;
   }
+
+  getUrl(): URL {
+    // id_ stores the relative path, convert it to the URL of the file
+    const absPath = path.resolve(this.id_);
+    return new URL(`file://${absPath}`);
+  }
 }
 
 export class ImageDocument<T extends Metadata = Metadata> extends ImageNode<T> {
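
getUrl resolves the node's id_ (a relative path) against the current working directory to produce a file:// URL. A quick standalone sketch of the same conversion; the path here is a hypothetical example value:

import path from "path";

// Mirrors ImageNode.getUrl(): relative path stored in id_ -> absolute file:// URL.
const id_ = "examples/multimodal/data/1.jpg"; // hypothetical id_ value
const absPath = path.resolve(id_); // depends on process.cwd()
console.log(new URL(`file://${absPath}`).href);
// e.g. file:///home/user/LlamaIndexTS/examples/multimodal/data/1.jpg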

packages/core/src/embeddings/utils.ts

Lines changed: 61 additions & 1 deletion
@@ -1,7 +1,7 @@
 import _ from "lodash";
 import { ImageType } from "../Node";
 import { DEFAULT_SIMILARITY_TOP_K } from "../constants";
-import { VectorStoreQueryMode } from "../storage";
+import { DEFAULT_FS, VectorStoreQueryMode } from "../storage";
 import { SimilarityType } from "./types";
 
 /**
@@ -185,6 +185,16 @@ export function getTopKMMREmbeddings(
   return [resultSimilarities, resultIds];
 }
 
+async function blobToDataUrl(input: Blob) {
+  const { fileTypeFromBuffer } = await import("file-type");
+  const buffer = Buffer.from(await input.arrayBuffer());
+  const type = await fileTypeFromBuffer(buffer);
+  if (!type) {
+    throw new Error("Unsupported image type");
+  }
+  return "data:" + type.mime + ";base64," + buffer.toString("base64");
+}
+
 export async function readImage(input: ImageType) {
   const { RawImage } = await import("@xenova/transformers");
   if (input instanceof Blob) {
@@ -195,3 +205,53 @@ export async function readImage(input: ImageType) {
     throw new Error(`Unsupported input type: ${typeof input}`);
   }
 }
+
+export async function imageToString(input: ImageType): Promise<string> {
+  if (input instanceof Blob) {
+    // if the image is a Blob, convert it to a base64 data URL
+    return await blobToDataUrl(input);
+  } else if (_.isString(input)) {
+    return input;
+  } else if (input instanceof URL) {
+    return input.toString();
+  } else {
+    throw new Error(`Unsupported input type: ${typeof input}`);
+  }
+}
+
+export function stringToImage(input: string): ImageType {
+  if (input.startsWith("data:")) {
+    // if the input is a base64 data URL, convert it back to a Blob
+    const base64Data = input.split(",")[1];
+    const byteArray = Buffer.from(base64Data, "base64");
+    return new Blob([byteArray]);
+  } else if (input.startsWith("http://") || input.startsWith("https://")) {
+    return new URL(input);
+  } else if (_.isString(input)) {
+    return input;
+  } else {
+    throw new Error(`Unsupported input type: ${typeof input}`);
+  }
+}
+
+export async function imageToDataUrl(input: ImageType): Promise<string> {
+  // first ensure, that the input is a Blob
+  if (
+    (input instanceof URL && input.protocol === "file:") ||
+    _.isString(input)
+  ) {
+    // string or file URL
+    const fs = DEFAULT_FS;
+    const dataBuffer = await fs.readFile(
+      input instanceof URL ? input.pathname : input,
+    );
+    input = new Blob([dataBuffer]);
+  } else if (!(input instanceof Blob)) {
+    if (input instanceof URL) {
+      throw new Error(`Unsupported URL with protocol: ${input.protocol}`);
+    } else {
+      throw new Error(`Unsupported input type: ${typeof input}`);
+    }
+  }
+  return await blobToDataUrl(input);
+}
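
imageToDataUrl accepts a path string, a file:// URL, or a Blob, reads the bytes, and base64 encodes them with the MIME type detected by file-type (the dependency added in package.json above). A standalone sketch of that conversion for a local file; the path is hypothetical and file-type is assumed to be installed:

import { promises as fs } from "fs";

// Roughly what blobToDataUrl/imageToDataUrl do for a local image file.
async function toDataUrl(filePath: string): Promise<string> {
  // file-type is ESM-only, hence the dynamic import (as in the code above).
  const { fileTypeFromBuffer } = await import("file-type");
  const buffer = await fs.readFile(filePath);
  const type = await fileTypeFromBuffer(buffer);
  if (!type) {
    throw new Error("Unsupported image type");
  }
  return `data:${type.mime};base64,${buffer.toString("base64")}`;
}

// Usage (hypothetical path):
toDataUrl("examples/multimodal/data/1.jpg").then((dataUrl) =>
  console.log(dataUrl.slice(0, 48) + "..."),
);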

packages/core/src/indices/vectorStore/VectorIndexRetriever.ts

Lines changed: 23 additions & 8 deletions
@@ -1,10 +1,10 @@
-import { Event } from "../../callbacks/CallbackManager";
-import { DEFAULT_SIMILARITY_TOP_K } from "../../constants";
-import { BaseEmbedding } from "../../embeddings";
 import { globalsHelper } from "../../GlobalsHelper";
-import { Metadata, NodeWithScore } from "../../Node";
+import { ImageNode, Metadata, NodeWithScore } from "../../Node";
 import { BaseRetriever } from "../../Retriever";
 import { ServiceContext } from "../../ServiceContext";
+import { Event } from "../../callbacks/CallbackManager";
+import { DEFAULT_SIMILARITY_TOP_K } from "../../constants";
+import { BaseEmbedding } from "../../embeddings";
 import {
   VectorStoreQuery,
   VectorStoreQueryMode,
@@ -18,20 +18,23 @@ import { VectorStoreIndex } from "./VectorStoreIndex";
 
 export class VectorIndexRetriever implements BaseRetriever {
   index: VectorStoreIndex;
-  similarityTopK;
+  similarityTopK: number;
+  imageSimilarityTopK: number;
   private serviceContext: ServiceContext;
 
   constructor({
     index,
     similarityTopK,
+    imageSimilarityTopK,
   }: {
     index: VectorStoreIndex;
    similarityTopK?: number;
+    imageSimilarityTopK?: number;
   }) {
     this.index = index;
     this.serviceContext = this.index.serviceContext;
-
     this.similarityTopK = similarityTopK ?? DEFAULT_SIMILARITY_TOP_K;
+    this.imageSimilarityTopK = imageSimilarityTopK ?? DEFAULT_SIMILARITY_TOP_K;
   }
 
   async retrieve(
@@ -51,7 +54,11 @@ export class VectorIndexRetriever implements BaseRetriever {
     query: string,
     preFilters?: unknown,
   ): Promise<NodeWithScore[]> {
-    const q = await this.buildVectorStoreQuery(this.index.embedModel, query);
+    const q = await this.buildVectorStoreQuery(
+      this.index.embedModel,
+      query,
+      this.similarityTopK,
+    );
     const result = await this.index.vectorStore.query(q, preFilters);
     return this.buildNodeListFromQueryResult(result);
   }
@@ -64,6 +71,7 @@
     const q = await this.buildVectorStoreQuery(
       this.index.imageEmbedModel,
       query,
+      this.imageSimilarityTopK,
     );
     const result = await this.index.imageVectorStore.query(q, preFilters);
     return this.buildNodeListFromQueryResult(result);
@@ -89,13 +97,14 @@
   protected async buildVectorStoreQuery(
     embedModel: BaseEmbedding,
     query: string,
+    similarityTopK: number,
   ): Promise<VectorStoreQuery> {
     const queryEmbedding = await embedModel.getQueryEmbedding(query);
 
     return {
       queryEmbedding: queryEmbedding,
       mode: VectorStoreQueryMode.DEFAULT,
-      similarityTopK: this.similarityTopK,
+      similarityTopK: similarityTopK,
     };
   }
 
@@ -108,6 +117,12 @@
       }
 
       const node = this.index.indexStruct.nodesDict[result.ids[i]];
+      // XXX: Hack, if it's an image node, we reconstruct the image from the URL
+      // Alternative: Store image in doc store and retrieve it here
+      if (node instanceof ImageNode) {
+        node.image = node.getUrl();
+      }
+
      nodesWithScores.push({
        node: node,
        score: result.similarities[i],
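
With the new option, the text and image vector stores are queried with independent top-k values. A minimal usage sketch, assuming VectorIndexRetriever is exported from the package entry point and `index` stands for a VectorStoreIndex built with both stores (as in the example's createIndex):

import { VectorIndexRetriever, VectorStoreIndex } from "llamaindex";

declare const index: VectorStoreIndex; // built elsewhere with text + image vector stores

// Construct the retriever directly...
const retriever = new VectorIndexRetriever({
  index,
  similarityTopK: 3,      // top-k applied to the text vector store
  imageSimilarityTopK: 1, // top-k applied to the image vector store
});

// ...or equivalently via the index, as in examples/multimodal/rag.ts:
const sameConfig = index.asRetriever({ similarityTopK: 3, imageSimilarityTopK: 1 });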

packages/core/src/synthesizers/MultiModalResponseSynthesizer.ts

Lines changed: 18 additions & 8 deletions
@@ -1,8 +1,14 @@
 import { MessageContentDetail } from "../ChatEngine";
-import { MetadataMode, NodeWithScore, splitNodesByType } from "../Node";
+import {
+  ImageNode,
+  MetadataMode,
+  NodeWithScore,
+  splitNodesByType,
+} from "../Node";
 import { Response } from "../Response";
 import { ServiceContext, serviceContextFromDefaults } from "../ServiceContext";
 import { Event } from "../callbacks/CallbackManager";
+import { imageToDataUrl } from "../embeddings";
 import { TextQaPrompt, defaultTextQaPrompt } from "./../Prompt";
 import { BaseSynthesizer } from "./types";
 
@@ -34,15 +40,19 @@ export class MultiModalResponseSynthesizer implements BaseSynthesizer {
     // TODO: use builders to generate context
     const context = textChunks.join("\n\n");
     const textPrompt = this.textQATemplate({ context, query });
-    // TODO: get images from imageNodes
+    const images = await Promise.all(
+      imageNodes.map(async (node: ImageNode) => {
+        return {
+          type: "image_url",
+          image_url: {
+            url: await imageToDataUrl(node.image),
+          },
+        } as MessageContentDetail;
+      }),
+    );
     const prompt: MessageContentDetail[] = [
       { type: "text", text: textPrompt },
-      {
-        type: "image_url",
-        image_url: {
-          url: "https://upload.wikimedia.org/wikipedia/commons/b/b0/Vincent_van_Gogh_%281853-1890%29_Caf%C3%A9terras_bij_nacht_%28place_du_Forum%29_Kr%C3%B6ller-M%C3%BCller_Museum_Otterlo_23-8-2016_13-35-40.JPG",
-        },
-      },
+      ...images,
     ];
     let response = await this.serviceContext.llm.complete(prompt, parentEvent);
     return new Response(response.message.content, nodes);
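
Instead of the hard-coded Wikipedia URL, the prompt sent to the LLM is now assembled from the retrieved image nodes. It ends up looking roughly like this sketch; the data URL is truncated and made up, and MessageContentDetail is assumed to be re-exported by the package (inside core it comes from ../ChatEngine):

import { MessageContentDetail } from "llamaindex";

// One text part containing the templated context + query, followed by one
// image_url part per retrieved image node, each base64 encoded by imageToDataUrl.
const prompt: MessageContentDetail[] = [
  { type: "text", text: "<text QA prompt built from context and query>" },
  {
    type: "image_url",
    image_url: { url: "data:image/jpeg;base64,/9j/4AAQSkZJRg..." },
  },
];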
