Retrievers inherit from the abstract Retriever class and thus need to follow this structure:
from models.retrievers.retriever import Retriever

class NewRetriever(Retriever):
    def __init__(self, model_name=None):
        self.model_name = model_name or 'new_retriever'

    # implement the abstract methods of Retriever (no @abstractmethod on the subclass)
    def __call__(self, kwargs):
        # model inference, e.g. return self.model(**kwargs)
        pass

    def collate_fn(self, batch, query_or_doc=None):
        # implement collate_fn here
        pass

    def similarity_fn(self, q_embs, doc_embs):
        # similarity fn to use, e.g. return torch.mm(q_embs, doc_embs.t())
        pass
We save it under models/retrievers/new_retriever.py.
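To see how the three methods work together, and why @abstractmethod belongs on the base class rather than on the subclass, here is a self-contained toy. The Retriever class below is a hypothetical stand-in for models.retrievers.retriever.Retriever (the real base class may differ), and the bag-of-words "embeddings" over a fixed vocabulary are purely illustrative.

```python
from abc import ABC, abstractmethod


class Retriever(ABC):
    """Hypothetical stand-in for models.retrievers.retriever.Retriever."""

    @abstractmethod
    def __call__(self, kwargs): ...

    @abstractmethod
    def collate_fn(self, batch, query_or_doc=None): ...

    @abstractmethod
    def similarity_fn(self, q_embs, doc_embs): ...


class ToyRetriever(Retriever):
    """Bag-of-words retriever over a fixed vocabulary (illustration only)."""

    def __init__(self, vocab, model_name=None):
        self.model_name = model_name or 'toy_retriever'
        self.vocab = vocab

    def collate_fn(self, batch, query_or_doc=None):
        # Tokenize each text; a real retriever would call its tokenizer here.
        return {'tokens': [text.lower().split() for text in batch]}

    def __call__(self, kwargs):
        # "Encode" each text as term counts over the vocabulary.
        return [[toks.count(term) for term in self.vocab] for toks in kwargs['tokens']]

    def similarity_fn(self, q_embs, doc_embs):
        # Dot product of every query with every document, like q_embs @ doc_embs.T
        return [[sum(q * d for q, d in zip(qe, de)) for de in doc_embs] for qe in q_embs]


retriever = ToyRetriever(vocab=['paris', 'france', 'berlin'])
q = retriever(retriever.collate_fn(['Where is Paris'], 'query'))
d = retriever(retriever.collate_fn(['Paris is in France', 'Berlin is in Germany'], 'doc'))
print(retriever.similarity_fn(q, d))  # [[1, 0]]
```

Because ToyRetriever implements every abstract method, it can be instantiated; a subclass that leaves one out raises a TypeError on instantiation.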
As a second step, create a config for this model under config/retrievers/new_retriever.yaml:
init_args:
  _target_: models.retrievers.new_retriever.NewRetriever
  model_name: "new_retriever"
batch_size: 1024
batch_size_sim: 256
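The config follows Hydra's conventions: _target_ names the class to instantiate, and the other keys under init_args become its keyword arguments, so this config amounts to NewRetriever(model_name="new_retriever"). Hydra does this with hydra.utils.instantiate; the stdlib-only helper instantiate_target below is a hypothetical approximation of that idea, not Hydra's implementation, and it uses standard-library classes as stand-ins for NewRetriever.

```python
import importlib


def instantiate_target(init_args: dict):
    """Hypothetical sketch: build the object described by an init_args mapping."""
    kwargs = dict(init_args)
    # "package.module.ClassName" -> module path and class name
    module_path, _, class_name = kwargs.pop('_target_').rpartition('.')
    cls = getattr(importlib.import_module(module_path), class_name)
    # All remaining keys (e.g. model_name) are passed as keyword arguments.
    return cls(**kwargs)


# Same shape as init_args above, with a stdlib class standing in for NewRetriever
half = instantiate_target({'_target_': 'fractions.Fraction', 'numerator': 1, 'denominator': 2})
print(half)  # 1/2
```

Keys outside init_args, such as batch_size, are not passed to the constructor in this sketch; presumably the surrounding pipeline reads them.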
To use the model, add the argument retriever='new_retriever':
python3 main.py retriever='new_retriever'
Rerankers inherit from the abstract Reranker class and thus need to follow this structure:
from models.rerankers.reranker import Reranker

class NewReranker(Reranker):
    def __init__(self, model_name=None):
        self.model_name = model_name or 'new_reranker'

    # implement the abstract methods of Reranker (no @abstractmethod on the subclass)
    def __call__(self, kwargs):
        # model inference, e.g. return self.model(**kwargs)
        pass

    def collate_fn(self, batch, query_or_doc=None):
        # implement collate function
        pass
We save it under models/rerankers/new_reranker.py.
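As an illustration of the reranker contract, the toy below scores each (query, document) pair and reorders the documents by score. The Reranker class is a hypothetical stand-in for models.rerankers.reranker.Reranker, the (query, document) batch layout is an assumption, and word overlap replaces the neural cross-encoder a real reranker would use.

```python
from abc import ABC, abstractmethod


class Reranker(ABC):
    """Hypothetical stand-in for models.rerankers.reranker.Reranker."""

    @abstractmethod
    def __call__(self, kwargs): ...

    @abstractmethod
    def collate_fn(self, batch, query_or_doc=None): ...


class OverlapReranker(Reranker):
    """Scores (query, document) pairs by word overlap (illustration only)."""

    def __init__(self, model_name=None):
        self.model_name = model_name or 'overlap_reranker'

    def collate_fn(self, batch, query_or_doc=None):
        # batch: list of (query, document) pairs, as a cross-encoder would see them
        return {'pairs': [(set(q.lower().split()), set(d.lower().split())) for q, d in batch]}

    def __call__(self, kwargs):
        # One relevance score per pair; a real model would return logits here.
        return [len(q & d) for q, d in kwargs['pairs']]


reranker = OverlapReranker()
query = 'capital of France'
docs = ['Berlin is the capital of Germany', 'Paris is the capital of France']
scores = reranker(reranker.collate_fn([(query, doc) for doc in docs]))
reranked = [doc for _, doc in sorted(zip(scores, docs), reverse=True)]
print(reranked[0])  # Paris is the capital of France
```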
As a second step, create a config for this model under config/rerankers/new_reranker.yaml:
init_args:
  _target_: models.rerankers.new_reranker.NewReranker
  model_name: "new_reranker"
batch_size: 2048
To use the model, add the argument reranker='new_reranker':
python3 main.py reranker='new_reranker'
Generators inherit from the abstract Generator class and thus need to follow this structure:
from models.generators.generator import Generator

class NewGenerator(Generator):
    def __init__(self, model_name=None):
        self.model_name = model_name or 'new_generator'

    # implement the abstract methods of Generator (no @abstractmethod on the subclass)
    def generate(self, inp):
        # generation, e.g. self.model(**inp)
        pass

    def collate_fn(self, inp):
        # implement collate function
        pass

    # only required for training
    def prediction_step(self, model, model_input, label_ids=None):
        # e.g.
        # output = model(**model_input, labels=label_ids)
        # return output.logits, output.loss
        pass
We save it under models/generators/new_generator.py.
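The toy below shows the generator contract end to end without a language model. The Generator class is a hypothetical stand-in for models.generators.generator.Generator, and the input fields 'query' and 'doc' are assumptions about the batch layout. Since prediction_step is only needed for training, this non-trainable toy simply raises NotImplementedError there.

```python
from abc import ABC, abstractmethod


class Generator(ABC):
    """Hypothetical stand-in for models.generators.generator.Generator."""

    @abstractmethod
    def generate(self, inp): ...

    @abstractmethod
    def collate_fn(self, inp): ...

    @abstractmethod
    def prediction_step(self, model, model_input, label_ids=None): ...


class FirstSentenceGenerator(Generator):
    """'Answers' with the first sentence of the top document (illustration only)."""

    def __init__(self, model_name=None):
        self.model_name = model_name or 'first_sentence'

    def collate_fn(self, inp):
        # inp: list of examples, each with a query and its retrieved documents
        return {'docs': [example['doc'] for example in inp]}

    def generate(self, inp):
        # One answer per example; a real generator would run the LLM here.
        return [docs[0].split('.')[0] if docs else '' for docs in inp['docs']]

    def prediction_step(self, model, model_input, label_ids=None):
        # Only needed for training, which this toy does not support.
        raise NotImplementedError('FirstSentenceGenerator cannot be trained')


gen = FirstSentenceGenerator()
batch = gen.collate_fn([{'query': 'Where is Paris?', 'doc': ['Paris is in France. It is big.']}])
print(gen.generate(batch))  # ['Paris is in France']
```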
As a second step, create a config for this model under config/generators/new_generator.yaml:
defaults:
  - prompt: basic
init_args:
  _target_: models.generators.new_generator.NewGenerator
  model_name: "new_generator"
  max_new_tokens: 128
batch_size: 32
max_inp_length: null
To use the model, add the argument generator='new_generator':
python3 main.py generator='new_generator'
A dataset config contains two entries: doc for the collection and query for the queries.
After processing, a query dataset must contain the fields id, wikipedia_id (optional), content, and label.
After processing, a document dataset must contain the fields id and content.
Define a new dataset class in modules/dataset_processor.py:
import datasets  # if not already imported in modules/dataset_processor.py

class NewDataset(Processor):
    def __init__(self, *args, **kwargs):
        # name under which the dataset will be saved: 'datasets/new_dataset_{split}' (default)
        dataset_name = 'new_dataset'
        super().__init__(*args, **kwargs, dataset_name=dataset_name)

    def process(self):
        # load the raw dataset, e.g. from the HF hub ('hf_dataset_name' is a placeholder)
        dataset = datasets.load_dataset('hf_dataset_name')

        def map_fn(example):
            # map each example to the required fields (e.g. id, content)
            return example

        dataset = dataset.map(map_fn, num_proc=self.num_proc)
        return dataset
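As an example of what map_fn might do for a query dataset, the sketch below renames raw fields to id, content, and label. The raw field names qid, question, and answers are hypothetical; adapt them to the actual source dataset. With datasets.Dataset.map, the returned keys are added to (or overwrite) the existing columns, so pass remove_columns if the raw columns should be dropped.

```python
def map_fn(example: dict) -> dict:
    """Hypothetical mapping from a raw query record to the required fields."""
    return {
        'id': str(example['qid']),  # ids as strings
        'content': example['question'].strip(),  # the query text
        'label': [answer['text'] for answer in example['answers']],  # gold answers
    }


raw = {'qid': 42, 'question': ' who wrote hamlet ', 'answers': [{'text': 'William Shakespeare'}]}
print(map_fn(raw))
```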
To use the dataset, add a new dataset config, e.g. config/dataset/new_config.yaml, that uses the new class NewDataset for the collection (doc field). For the queries, we use the already existing dataset KILTNQProcessor. Additionally, add the field split, which defines which split of the dataset should be used:
test:
  doc: null
  query: null
dev:
  doc:
    init_args:
      _target_: modules.dataset_processor.NewDataset
      split: "full"
  query:
    init_args:
      _target_: modules.dataset_processor.KILTNQProcessor
      split: "validation"
train:
  doc: null
  query: null
All datasets can be overwritten by adding +overwrite_datasets=True as an argument (Caution: this might overwrite collections that take a long time to encode). If indexing is interrupted, you can continue encoding a collection from, e.g., batch 1000 by additionally using the argument +continue_batch=1000.
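To make resuming concrete, the sketch below lists which batches remain to be encoded after an interruption. It assumes zero-indexed, fixed-size batches; the helper batches_to_encode is hypothetical and not BERGEN's indexing code. Under these assumptions, and if the encoding batch size is the retriever's batch_size: 1024, resuming with +continue_batch=1000 would skip the first 1000 × 1024 = 1,024,000 documents.

```python
def batches_to_encode(num_items: int, batch_size: int, continue_batch: int = 0):
    """Hypothetical sketch: yield (batch_index, start, end) for the batches still to encode."""
    num_batches = (num_items + batch_size - 1) // batch_size  # ceiling division
    for batch_index in range(continue_batch, num_batches):
        start = batch_index * batch_size
        # The last batch may be shorter than batch_size.
        yield batch_index, start, min(start + batch_size, num_items)


# Resuming a 2,500-document collection with batch_size=1000 from batch 1
print(list(batches_to_encode(2500, 1000, continue_batch=1)))
# [(1, 1000, 2000), (2, 2000, 2500)]
```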
Prompts are stored in config/prompt/ and selected via the argument prompt.
Create a new prompt new_prompt under config/prompt/new_prompt.yaml. An example prompt could look like this. The local variables (e.g. question and docs) will be inserted into the formatted string within the respective model's format_instruction() function.
Important: a space that follows a colon within the formatted string needs to be escaped with a backslash, e.g. Question:\ {question}.
system: "You are a helpful assistant. Your task is to extract relevant information from the provided documents and to answer questions accordingly."
user: f"Background:\ {docs}\n\nQuestion:\ {question}\nAnswer:"
system_without_docs: "You are a helpful assistant."
user_without_docs: f"Question:\ {question}\nAnswer:"
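To make the escaping rule concrete, the sketch below turns a user value from the example prompt into the final text: it strips the f"..." wrapper, turns each :\ back into ': ', turns literal \n sequences into newlines, and fills in the variables. The helper render_prompt is hypothetical and is not the models' format_instruction(), which may work differently (for example by evaluating the f-string directly).

```python
def render_prompt(template: str, **variables: str) -> str:
    """Hypothetical helper: render a prompt value as read from the YAML file."""
    # Strip the f"..." wrapper the value is written with.
    if template.startswith('f"') and template.endswith('"'):
        template = template[2:-1]
    # Undo the escaped spaces after colons and turn literal \n into newlines.
    template = template.replace(':\\ ', ': ').replace('\\n', '\n')
    # Fill in the local variables, e.g. docs and question.
    return template.format(**variables)


user = r'f"Question:\ {question}\nAnswer:"'  # value of user_without_docs
print(render_prompt(user, question='Who wrote Hamlet?'))
```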
To use the prompt, pass it as an argument:
python3 main.py generator='tinyllama-chat' prompt='new_prompt'
Using the oracle answers instead of generating with an LLM.
To run the generation, simply use the generator oracle_answer, for example:
python3 main.py dataset='kilt_nq' generator='oracle_answer'
To generate all oracle runs (TREC runs) and save them in runs, execute the script scripts/kilt_generate_oracle.py once.
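A TREC run file has one line per retrieved document in the standard six-column format qid Q0 docid rank score run_tag. The sketch below writes an oracle run in that format, listing the gold documents of each query in order. The function trec_run_lines, the rank-derived scores, and the document ids are assumptions for illustration; scripts/kilt_generate_oracle.py may produce its runs differently.

```python
def trec_run_lines(run: dict, run_name: str) -> list:
    """Format {query_id: [doc_id, ...]} as TREC run lines: qid Q0 docid rank score tag."""
    lines = []
    for qid, doc_ids in run.items():
        for rank, doc_id in enumerate(doc_ids, start=1):
            # Higher score = more relevant; here it is simply derived from the rank.
            score = len(doc_ids) - rank + 1
            lines.append(f'{qid} Q0 {doc_id} {rank} {score} {run_name}')
    return lines


oracle = {'q1': ['wiki_123', 'wiki_456']}
print('\n'.join(trec_run_lines(oracle, 'oracle_provenance')))
# q1 Q0 wiki_123 1 2 oracle_provenance
# q1 Q0 wiki_456 2 1 oracle_provenance
```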
Generating answers using Llama with the oracle provenances as documents.
To run the generation with, e.g., llama-2-7b-chat, simply select oracle_provenance as the retriever, for example:
python3 main.py dataset='kilt_nq' retriever='oracle_provenance' generator='llama-2-7b-chat'
If you want to develop new functionality or test basic BERGEN configurations, run all tests:
pytest tests/