Skip to content

Commit

Permalink
feat: Add jina-ai late-chunking sample codes
Browse files Browse the repository at this point in the history
  • Loading branch information
YeonwooSung committed Nov 19, 2024
1 parent ff0a6a7 commit 1d70781
Show file tree
Hide file tree
Showing 19 changed files with 2,517 additions and 0 deletions.
31 changes: 31 additions & 0 deletions LLMs/RAG/late-chunking/.github/workflows/ci.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
name: Run Tests

on:
pull_request:
types: [opened, synchronize, reopened]
push:
branches:
- main

env:
JINA_API_TOKEN: ${{ secrets.JINA_API_TOKEN }}

jobs:
test:
runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v3

- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.11'

- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install .[dev]
- name: Run tests
run: pytest tests
77 changes: 77 additions & 0 deletions LLMs/RAG/late-chunking/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# Late Chunking of Short Chunks in Long-Context Embedding Models

For many applications, encoding a whole text document into a single embedding representation is not useful. Many applications require retrieving smaller parts of the text and dense vector-based information retrieval systems often perform better with smaller text segments because of the limited information capacity of embedding vectors.

![img.png](img/rag.png)


RAG (Retrieval Augmented Generations) is one of the best known applications to require splitting document collections into smaller text chunks. These chunks are typically stored in a vector database with vector representations created by a text embedding model.
At runtime, the same embedding model encodes a query text into a vector representation, which is used to identify relevant stored text chunks. These are them passed to a large language model (LLM) which synthesizes a response to the query based on the retrieved texts.

## Context Problem


This simple RAG approach is not without challenges. Long distance contextual dependencies, i.e. when the relevant information is spread over multiple chunks and taking text segments out of context makes them useless, are particularly poorly handled by this approach.
![img.png](img/context-problem.png)
In the image above one can see an Wikipedia article that is split into chunks of sentences.
One can see that phrases like "its" and "the city" referencing "Berlin" which is mentioned only in the first sentence, e.g., it is harder for the embedding model to link it to the respective entity to produce a high-quality embedding representation.


For example, if we split a Wikipedia article into sentence-length segments, as in the example above, a RAG system might not be able to answer a query like "What is the population of Berlin?" The city name and the population never appear together in a single segment, and lacking any larger document context.
An LLM to which one of the segments is presented cannot resolve the anaphoric references like "it" or "the city".

## Context-Sensitive Chunking

To overcome this problem, we take advantage of the long input sequences that recent embedding models like [`jina-embeddings-v2-base-en`](https://huggingface.co/jinaai/jina-embeddings-v2-base-en) can process.
These models support much longer input texts, for example, 8192 tokens for `jina-embeddings-v2-base-en` or roughly ten standard pages of text. Text segments of this size are much less likely to have contextual dependencies that can only be resolved with a larger context.
However, we still need vector representations of much smaller chunks of text, in part because of the limited input sizes of LLMs but primarily because of the limited information capacity of short embedding vectors.

![img.png](img/method.png)


The simple encoding approach (as seen on the left side of the image above) chunks texts before processing them, using sentences, paragraphs, and maximum length limits to split text _a priori_, and then applying an embedding model to the resulting chunks.
Late Chunking, instead, first applies the transformer part from the embedding model to the entire text, or the largest part of it possible. This generates a sequence of vector representations for each token that encompass textual information from the entire text.
To generate a single embedding for a text, many embedding models apply _mean pooling_ to these token representations to output a single vector. Late Chunking instead applies mean pooling to smaller segments of this sequence of token vectors, producing embeddings for each chunk that take into account the entire text.

## The Effect of Context-Sensitive Chunking

This has immediately measurable concrete effects on retrieval. As an example, in case of "the city" and "Berlin" in a Wikipedia article, the vectors representing "the city" contain information connecting it to the previous mention of "Berlin", making it a much better match for queries involving that city name.

You can see that in numerical results below, which compares the embedding of the string "Berlin" to various sentences from the article about Berlin. The column "Traditional Similarity" is the similarity values using _a priori_ chunking, and "Late Chunking Similarity" is with context-sensitive chunking.

| Text | Similarity Traditional | Similarity Late Chunking |
|---------------------------------------------------------------------------------------------------------------------------------------|------------------------|-------------------------------|
| Berlin is the capital and largest city of Germany, both by area and by population." | 0.84862185 | 0.849546 |
| Its more than 3.85 million inhabitants make it the European Union's most populous city, as measured by population within city limits. | 0.7084338 | 0.82489026 |
| The city is also one of the states of Germany, and is the third smallest state in the country in terms of area. | 0.7534553 | 0.84980094 |

As you can see the similarity scores for the first chunk that contains "Berlin" are very close to each other.
For the other two chunks they siginificantly differ, as the late chunking dramatically improves matching on sentences that do not explicitly use the word "Berlin" but have anaphoric references to it.

## Evaluation on Retrieval Tasks


To verify the effectiveness of this approach beyond a few toy examples, we tested it with some of the retrieval benchmarks from [BeIR](https://github.com/beir-cellar/beir).
Those retrieval tasks consist of a query set, a corpus of text documents, and a QRels file that stores information about the IDs of documents that are relevant for each query.
To identify the relevant documents of a query, one can chunk the documents, encode them into an embedding index, and determine for each query embedding the most similar chunks (kNN).
As each chunk corresponds to a document, one can convert the kNN ranking of chunks into a kNN ranking of documents (for documents occurring multiple times in the ranking, only the first occurrence is retained).
After that, one can compare the resulting ranking with the ranking corresponding to the ground-truth QRels file and calculate retrieval metrics like nDCG@10.
We run this evaluation for various BeIR datasets with traditional chunking and our novel late chunking method.
To split texts into chunks, we choose a straightforward method, which chunks the tests into strings of 256 tokens.
Both the traditional and late chunking tests used the [jina-embeddings-v2-small-en](https://huggingface.co/jinaai/jina-embeddings-v2-small-en) model.

| Dataset | AVG Document Length (characters) | Traditional Chunking (nDCG@10) | Late Chunking (nDCG@10) | No Chunking (nDCG@10) |
|-----------|----------------------------------|--------------------------------|--------------------------------------|-----------------------|
| SciFact | 1498.4 | 64.20% | **66.10%** | 63.89% |
| TRECCOVID | 1116.7 | 63.36% | 64.70% | **65.18%** |
| FiQA2018 | 767.2 | 33.25% | **33.84%** | 33.43% |
| NFCorpus | 1589.8 | 23.46% | 29.98% | **30.40%** |
| Quora | 62.2 | 87.19% | 87.19% | 87.19% |

In all cases, late chunking improved the score. In some cases, it also outperforms encoding the whole document into a single embedding, while for other datasets, no chunking performs best. However, this only makes sense if one does not need to rank chunks. One can also see that the average length of the documents correlates with greater improvement in the nDCG scores through late chunking.

To reporoduce the evaluation, you can install the dependencies with `pip install .` and run the following script for the tasks "SciFactChunked", "TRECCOVIDChunked", "FiQA2018Chunked", "NFCorpusChunked", and "QuoraChunked":

```bash
python3 run_chunked_eval.py --task-name {TASK_NAME}
```
56 changes: 56 additions & 0 deletions LLMs/RAG/late-chunking/chunked_pooling/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
def chunk_by_sentences(input_text: str, tokenizer: callable):
"""
Split the input text into sentences using the tokenizer
:param input_text: The text snippet to split into sentences
:param tokenizer: The tokenizer to use
:return: A tuple containing the list of text chunks and their corresponding token spans
"""
inputs = tokenizer(input_text, return_tensors='pt', return_offsets_mapping=True)
punctuation_mark_id = tokenizer.convert_tokens_to_ids('.')
sep_id = tokenizer.convert_tokens_to_ids('[SEP]')
token_offsets = inputs['offset_mapping'][0]
token_ids = inputs['input_ids'][0]
chunk_positions = [
(i, int(start + 1))
for i, (token_id, (start, end)) in enumerate(zip(token_ids, token_offsets))
if token_id == punctuation_mark_id
and (
token_offsets[i + 1][0] - token_offsets[i][1] > 0
or token_ids[i + 1] == sep_id
)
]
chunks = [
input_text[x[1] : y[1]]
for x, y in zip([(1, 0)] + chunk_positions[:-1], chunk_positions)
]
span_annotations = [
(x[0], y[0]) for (x, y) in zip([(1, 0)] + chunk_positions[:-1], chunk_positions)
]
return chunks, span_annotations


def chunked_pooling(
model_output: 'BatchEncoding', span_annotation: list, max_length=None
):
token_embeddings = model_output[0]
outputs = []
for embeddings, annotations in zip(token_embeddings, span_annotation):
if (
max_length is not None
): # remove annotations which go bejond the max-length of the model
annotations = [
(start, min(end, max_length - 1))
for (start, end) in annotations
if start < (max_length - 1)
]
pooled_embeddings = [
embeddings[start:end].sum(dim=0) / (end - start)
for start, end in annotations
if (end - start) >= 1
]
pooled_embeddings = [
embedding.float().detach().cpu().numpy() for embedding in pooled_embeddings
]
outputs.append(pooled_embeddings)

return outputs
Loading

0 comments on commit 1d70781

Please sign in to comment.