Skip to content

Commit

Permalink
[product doc] implement highlight summarizer (elastic#206578)
Browse files Browse the repository at this point in the history
## Summary

Fix elastic#205921

- Implements a new summary strategy for the product documentation, based
on `semantic_text` highlights
- Sets that new strategy as the default one

### Why ?

Until now, in case of excessive token count, we were using an LLM-based
summarizer. Realistically, highlights will always be worse than calling
an LLM for an "in context summary", but from my testing, highlights seem
"good enough", and the speed difference (instant for highlights vs
multiple seconds, up to a dozen, for the LLM summary) is very
significant, so the trade-off seems worth it overall.

The main upside of that change, given that requesting the product doc
will be much faster, is that we can then tweak the assistant's
instructions to more aggressively call the product_doc tool between each
user message without the risk of the user experience being impacted
(waiting way longer between messages). *This will be done as a
follow-up.*

### How to test ?

Install the product doc, ask the assistant questions, and check the tool
calls (sorry, I don't have a better option at the moment...)

Note: this works with both versions of the product doc artifacts, so
the dev repository is not needed
  • Loading branch information
pgayvallet authored Jan 16, 2025
1 parent 7e48400 commit c9286ec
Show file tree
Hide file tree
Showing 14 changed files with 308 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@ import { LlmTasksPlugin } from './plugin';
export { config } from './config';

export type { LlmTasksPluginSetup, LlmTasksPluginStart };
export type {
RetrieveDocumentationAPI,
RetrieveDocumentationParams,
RetrieveDocumentationResult,
RetrieveDocumentationResultDoc,
} from './tasks';

export const plugin: PluginInitializer<
LlmTasksPluginSetup,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,10 @@
* 2.0.
*/

export { retrieveDocumentation } from './retrieve_documentation';
export {
retrieveDocumentation,
type RetrieveDocumentationParams,
type RetrieveDocumentationResultDoc,
type RetrieveDocumentationResult,
type RetrieveDocumentationAPI,
} from './retrieve_documentation';
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
export { retrieveDocumentation } from './retrieve_documentation';
export type {
RetrieveDocumentationAPI,
RetrieveDocumentationResultDoc,
RetrieveDocumentationResult,
RetrieveDocumentationParams,
} from './types';
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
import { httpServerMock } from '@kbn/core/server/mocks';
import { loggerMock, type MockedLogger } from '@kbn/logging-mocks';
import type { DocSearchResult } from '@kbn/product-doc-base-plugin/server/services/search';

import { retrieveDocumentation } from './retrieve_documentation';

import { truncate, count as countTokens } from '../../utils/tokens';
jest.mock('../../utils/tokens');
const truncateMock = truncate as jest.MockedFn<typeof truncate>;
Expand All @@ -26,12 +26,16 @@ describe('retrieveDocumentation', () => {
let searchDocAPI: jest.Mock;
let retrieve: ReturnType<typeof retrieveDocumentation>;

const createResult = (parts: Partial<DocSearchResult> = {}): DocSearchResult => {
const createResult = (
parts: Partial<DocSearchResult> = {},
{ highlights = [] }: { highlights?: string[] } = {}
): DocSearchResult => {
return {
title: 'title',
content: 'content',
url: 'url',
productName: 'kibana',
highlights,
...parts,
};
};
Expand All @@ -56,6 +60,7 @@ describe('retrieveDocumentation', () => {
const result = await retrieve({
searchTerm: 'What is Kibana?',
products: ['kibana'],
tokenReductionStrategy: 'highlight',
request,
max: 5,
connectorId: '.my-connector',
Expand All @@ -72,7 +77,65 @@ describe('retrieveDocumentation', () => {
query: 'What is Kibana?',
products: ['kibana'],
max: 5,
highlights: 4,
});
});

it('calls the search API with highlights=0 when using a different summary strategy', async () => {
  // The search returns nothing: only the parameters forwarded to the search API matter here.
  searchDocAPI.mockResolvedValue({ results: [] });

  // With a non-highlight strategy, no highlight should be requested from the search API.
  const expectedSearchParams = {
    query: 'What is Kibana?',
    products: ['kibana'],
    max: 5,
    highlights: 0,
  };

  await retrieve({
    searchTerm: 'What is Kibana?',
    products: ['kibana'],
    tokenReductionStrategy: 'truncate',
    request,
    max: 5,
    connectorId: '.my-connector',
    functionCalling: 'simulated',
  });

  expect(searchDocAPI).toHaveBeenCalledTimes(1);
  expect(searchDocAPI).toHaveBeenCalledWith(expectedSearchParams);
});

it('reduces the document length using the highlights strategy', async () => {
  // Three hits, each one carrying two highlights.
  const searchResults = [
    createResult({ content: 'content-1' }, { highlights: ['hl1-1', 'hl1-2'] }),
    createResult({ content: 'content-2' }, { highlights: ['hl2-1', 'hl2-2'] }),
    createResult({ content: 'content-3' }, { highlights: ['hl3-1', 'hl3-2'] }),
  ];
  searchDocAPI.mockResolvedValue({ results: searchResults });

  // Only the second document goes over the 100 tokens limit used below.
  countTokensMock.mockImplementation((text) => (text === 'content-2' ? 150 : 50));
  // Truncation is a pass-through, so the reduced content can be asserted verbatim.
  truncateMock.mockImplementation((val) => val);

  const { documents } = await retrieve({
    searchTerm: 'What is Kibana?',
    request,
    connectorId: '.my-connector',
    maxDocumentTokens: 100,
    tokenReductionStrategy: 'highlight',
  });

  // Documents under the limit keep their content; the oversized one is replaced
  // by its highlights joined with a blank line.
  expect(documents.map(({ content }) => content)).toEqual([
    'content-1',
    'hl2-1\n\nhl2-2',
    'content-3',
  ]);

  expect(truncateMock).toHaveBeenCalledTimes(1);
  expect(truncateMock).toHaveBeenCalledWith('hl2-1\n\nhl2-2', 100);
});

it('reduces the document length using the truncate strategy', async () => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import type { Logger } from '@kbn/logging';
import type { OutputAPI } from '@kbn/inference-common';
import type { ProductDocSearchAPI } from '@kbn/product-doc-base-plugin/server';
import type { ProductDocSearchAPI, DocSearchResult } from '@kbn/product-doc-base-plugin/server';
import { truncate, count as countTokens } from '../../utils/tokens';
import type { RetrieveDocumentationAPI } from './types';
import { summarizeDocument } from './summarize_document';
Expand All @@ -32,10 +32,40 @@ export const retrieveDocumentation =
functionCalling,
max = MAX_DOCUMENTS_DEFAULT,
maxDocumentTokens = MAX_TOKENS_DEFAULT,
tokenReductionStrategy = 'summarize',
tokenReductionStrategy = 'highlight',
}) => {
/**
 * Builds the reduced content of a document, following the configured
 * `tokenReductionStrategy`:
 * - `highlight`: concatenates the document's highlights, separated by a blank line.
 *   When the document has no highlights, its content is used instead, so that
 *   the document is not reduced to an empty string.
 * - `summarize`: calls `summarizeDocument` with the search term and the document's
 *   content, and uses the returned summary.
 * - `truncate`: uses the document's content as is.
 *
 * Whatever the strategy, the resulting text is then passed to `truncate` with
 * `maxDocumentTokens`.
 *
 * @param doc the search result to reduce
 * @returns the reduced content for the document
 */
const applyTokenReductionStrategy = async (doc: DocSearchResult): Promise<string> => {
  let content: string;
  switch (tokenReductionStrategy) {
    case 'highlight':
      // No highlights for this hit: fall back to the (truncated) content
      // rather than returning an empty document.
      content = doc.highlights.length > 0 ? doc.highlights.join('\n\n') : doc.content;
      break;
    case 'summarize': {
      // Braces keep the declaration scoped to this case only.
      const { summary } = await summarizeDocument({
        searchTerm,
        documentContent: doc.content,
        outputAPI,
        connectorId,
        functionCalling,
      });
      content = summary;
      break;
    }
    case 'truncate':
      content = doc.content;
      break;
  }
  return truncate(content, maxDocumentTokens);
};

try {
const { results } = await searchDocAPI({ query: searchTerm, products, max });
const highlights =
tokenReductionStrategy === 'highlight' ? calculateHighlightCount(maxDocumentTokens) : 0;
const { results } = await searchDocAPI({
query: searchTerm,
products,
max,
highlights,
});

log.debug(`searching with term=[${searchTerm}] returned ${results.length} documents`);

Expand All @@ -49,25 +79,15 @@ export const retrieveDocumentation =

let content = document.content;
if (docHasTooManyTokens) {
if (tokenReductionStrategy === 'summarize') {
const extractResponse = await summarizeDocument({
searchTerm,
documentContent: document.content,
outputAPI,
connectorId,
functionCalling,
});
content = truncate(extractResponse.summary, maxDocumentTokens);
} else {
content = truncate(document.content, maxDocumentTokens);
}
content = await applyTokenReductionStrategy(document);
}

log.debug(`done processing document [${document.url}]`);
return {
title: document.title,
url: document.url,
content,
summarized: docHasTooManyTokens,
};
})
);
Expand All @@ -86,3 +106,9 @@ export const retrieveDocumentation =
return { success: false, documents: [] };
}
};

/** Assumed average size, in tokens, of a single highlight. */
const AVG_TOKENS_PER_HIGHLIGHT = 250;

/**
 * Computes how many highlights to request per document for a given token budget:
 * the budget divided by the assumed average highlight size, rounded up.
 *
 * @param maxTokensPerDoc maximum number of tokens allowed per document
 * @returns the number of highlights to request
 */
const calculateHighlightCount = (maxTokensPerDoc: number): number =>
  Math.ceil(maxTokensPerDoc / AVG_TOKENS_PER_HIGHLIGHT);
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,16 @@ export interface RetrieveDocumentationParams {
maxDocumentTokens?: number;
/**
* The token reduction strategy to apply for documents exceeding max token count.
* - truncate: Will keep the N first tokens
* - summarize: Will call the LLM asking to generate a contextualized summary of the document
* - "highlight": Use Elasticsearch semantic highlighter to build a summary (concatenating highlights)
* - "truncate": Will keep the N first tokens
* - "summarize": Will call the LLM asking to generate a contextualized summary of the document
*
* Overall, `summarize` is way more efficient, but significantly slower, given that an additional
* Overall, `summarize` is more efficient, but significantly slower, given that an additional
* LLM call will be performed.
*
* Defaults to `summarize`
* Defaults to `highlight`
*/
tokenReductionStrategy?: 'truncate' | 'summarize';
tokenReductionStrategy?: 'highlight' | 'truncate' | 'summarize';
/**
* The request that initiated the task.
*/
Expand All @@ -53,20 +54,39 @@ export interface RetrieveDocumentationParams {
* Id of the LLM connector to use for the task.
*/
connectorId: string;
/**
* Optional functionCalling parameter to pass down to the inference APIs.
*/
functionCalling?: FunctionCallingMode;
}

export interface RetrievedDocument {
/**
* Individual result item in a {@link RetrieveDocumentationResult}
*/
export interface RetrieveDocumentationResultDoc {
/** title of the document */
title: string;
/** full url to the online documentation */
url: string;
/**
* content of the doc article.
* When `summarized` is true, this is the output of the token reduction strategy
* rather than the full text of the article.
*/
content: string;
/** true if content exceeded max token length and had to go through token reduction */
summarized: boolean;
}

/**
* Response type for {@link RetrieveDocumentationAPI}
*/
export interface RetrieveDocumentationResult {
/** whether the call was successful or not */
success: boolean;
documents: RetrievedDocument[];
/** List of results for this search */
documents: RetrieveDocumentationResultDoc[];
}

/**
* Retrieve documentation API
*
* Takes a {@link RetrieveDocumentationParams} options object and resolves with a
* {@link RetrieveDocumentationResult}.
*/
export type RetrieveDocumentationAPI = (
options: RetrieveDocumentationParams
) => Promise<RetrieveDocumentationResult>;
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,18 @@ import { ProductDocBasePlugin } from './plugin';
export { config } from './config';

export type { ProductDocBaseSetupContract, ProductDocBaseStartContract };
export type { SearchApi as ProductDocSearchAPI } from './services/search/types';
export type {
SearchApi as ProductDocSearchAPI,
DocSearchOptions,
DocSearchResult,
DocSearchResponse,
} from './services/search/types';

export const plugin: PluginInitializer<
ProductDocBaseSetupContract,
ProductDocBaseStartContract,
ProductDocBaseSetupDependencies,
ProductDocBaseStartDependencies
> = async (pluginInitializerContext: PluginInitializerContext<ProductDocBaseConfig>) =>
new ProductDocBasePlugin(pluginInitializerContext);
> = async (pluginInitializerContext: PluginInitializerContext<ProductDocBaseConfig>) => {
return new ProductDocBasePlugin(pluginInitializerContext);
};
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import { elasticsearchServiceMock } from '@kbn/core/server/mocks';
import { performSearch } from './perform_search';

describe('performSearch', () => {
  let esClient: ReturnType<typeof elasticsearchServiceMock.createElasticsearchClient>;

  /** Runs `performSearch` with fixed parameters, only varying the number of highlights. */
  const searchWithHighlights = (highlights: number) =>
    performSearch({
      searchQuery: 'query',
      highlights,
      size: 3,
      index: ['index1', 'index2'],
      client: esClient,
    });

  beforeEach(() => {
    esClient = elasticsearchServiceMock.createElasticsearchClient();

    // Every search resolves with an empty list of hits.
    esClient.search.mockResolvedValue({ hits: { hits: [] } } as any);
  });

  it('calls esClient.search with the correct parameters', async () => {
    await searchWithHighlights(3);

    // The query and highlight configuration are only checked for presence, not for shape.
    const expectedRequest = {
      index: ['index1', 'index2'],
      size: 3,
      query: expect.any(Object),
      highlight: {
        fields: {
          content_body: expect.any(Object),
        },
      },
    };

    expect(esClient.search).toHaveBeenCalledTimes(1);
    expect(esClient.search).toHaveBeenCalledWith(expectedRequest);
  });

  it('calls esClient.search without highlight when highlights=0', async () => {
    await searchWithHighlights(0);

    // The request must not carry any `highlight` section.
    const requestWithoutHighlight = expect.not.objectContaining({
      highlight: expect.any(Object),
    });

    expect(esClient.search).toHaveBeenCalledTimes(1);
    expect(esClient.search).toHaveBeenCalledWith(requestWithoutHighlight);
  });
});
Loading

0 comments on commit c9286ec

Please sign in to comment.