Provide an Inference engine for ctranslate2 (#2485)
* inference engine for the `py` inference mode
* added a ctranslate2 inference engine and simplified inference in the vicuna chatbot
* added infer_file method to InferenceEngineCT2 class
* added onmt/tests/test_model_lm_ct2
l-k-11235 authored Oct 18, 2023
1 parent 8bcc5ab commit 7dc0958
Showing 15 changed files with 17,052 additions and 251 deletions.
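
For orientation, here is a minimal sketch of how the new engines are driven, mirroring the test commands in the workflow below. The module path, parser helpers, and the return value of `infer_file` are assumptions inferred from the commit notes, not a verbatim copy of the committed API.

```python
# Hedged sketch: exercising the new ct2 inference engine the way
# onmt/tests/test_inference_engines.py does from the CLI. The module
# path and infer_file return value are assumptions.
from onmt.utils.parse import ArgumentParser
from onmt.opts import translate_opts
from onmt.inference_engine import InferenceEngineCT2  # assumed module path

parser = ArgumentParser()
translate_opts(parser)  # registers -model, -src, beam/sampling options, etc.
opt = parser.parse_args(
    ["-model", "onmt/tests/test_model_lm_ct2", "-src", "/tmp/src-test.txt"]
)

engine = InferenceEngineCT2(opt)
scores, preds = engine.infer_file()  # assumed to return scores and predictions
```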
30 changes: 30 additions & 0 deletions .github/workflows/push.yml
@@ -359,6 +359,36 @@ jobs:
-min_length 5 \
-out /tmp/gen
diff data/data_lm/gen-sampling-beams-sol$(python -c "import torch; print(torch.__version__[0])").txt /tmp/gen && rm /tmp/gen
- name: Test py-LM inference engine
run: |
head data/src-test.txt > /tmp/src-test.txt
python onmt/tests/test_inference_engines.py \
-model onmt/tests/test_model_lm.pt \
-model_task lm \
-input_file /tmp/src-test.txt \
-inference_config_file data/inference-engine_py.yaml \
-inference_mode py \
-out /tmp/inference_engine_lm_py_outputs
- name: Test ct2-LM inference engine
run: |
head data/src-test.txt > /tmp/src-test.txt
python onmt/tests/test_inference_engines.py \
-model onmt/tests/test_model_lm_ct2 \
-model_task lm \
-input_file /tmp/src-test.txt \
-inference_config_file data/inference-engine_py.yaml \
-inference_mode ct2 \
-out /tmp/inference_engine_lm_ct2_outputs
- name: Test py-SEQ2SEQ inference engine
run: |
head data/src-test.txt > /tmp/src-test.txt
python onmt/tests/test_inference_engines.py \
-model onmt/tests/test_model.pt \
-model_task seq2seq \
-input_file /tmp/src-test.txt \
-inference_config_file data/inference-engine_py.yaml \
-inference_mode py \
-out /tmp/inference_engine_seq2seq_py_outputs
- name: Test extract_vocabulary tool
run: |
python tools/extract_vocabulary.py \
11 changes: 11 additions & 0 deletions data/inference-engine_py.yaml
@@ -0,0 +1,11 @@
world_size: 0
max_length: 512
batch_type: sents
batch_size: 100
random_sampling_topk: 40
random_sampling_topp: 0.75
random_sampling_temp: 0.1
beam_size: 2
n_best: 2
src: None
model: None
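
For illustration, these options could be overlaid on a parsed options namespace before building an engine. A minimal sketch; the merge logic (and the normalization of the string "None") is an assumption, not code from this commit.

```python
# Sketch: loading data/inference-engine_py.yaml and applying it as
# attribute overrides. Note that YAML parses the bare word None as the
# string "None", not a null, so it is normalized here.
import argparse

import yaml

opt = argparse.Namespace()  # stand-in for parsed translate options
with open("data/inference-engine_py.yaml") as f:
    overrides = yaml.safe_load(f)

for key, value in overrides.items():
    setattr(opt, key, None if value == "None" else value)

print(opt.beam_size, opt.n_best)  # -> 2 2
```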
88 changes: 51 additions & 37 deletions docs/source/examples/replicate_vicuna/ReplicateVicuna.md
@@ -11,14 +11,16 @@ The maximal context length will be set to 512.
Here is a short description of the content of your current directory:

- The OpenNMT-py repository.
- The `replicate_vicuna.yaml` file with the finetuning options.
- A subdirectory named "llama" with the llama checkpoints.
- The llama7B checkpoint converted to `OpenNMT-py` format (`llama7B-vicuna-onmt`) and the vocabulary (`vocab.txt`). They will be generated with `OpenNMT-py` tools.
- A subdirectory named "dataAI" with the datasets for the finetuning.
- A subdirectory named "finetuned_llama7B" that will contain the finetuning samples, tensorboard logs and checkpoints.
- The `translate_opts.yaml` file with the translation options for the inference with `OpenNMT-py/onmt/bin/translate.py`.
- A subdirectory named "inputs" containing the `input_examples.txt` file with the input examples for the inference.
- A subdirectory named "finetuned_llama7B" that will contain the finetuning samples, the tensorboard logs and the checkpoints.
- The `translate_opts_py.yaml` file with the translation options for the inference with `translate.py`.
- The `translate_opts_ct2.yaml` file with the translation options for the inference with `cranslate2`.
- The `input_examples.txt` file with a few input examples.
- A subdirectory named "outputs" that will contain the inferred outputs of the finetuned model.
- The `simple_inference.py` file to compute vicuna's predictions from the `input_examples.txt` file, for the 2 different modes.
- The `chatbot.py` script (for the ctranslate2 inference with a gradio application).

## Dependencies
@@ -104,13 +106,13 @@ Below is an instruction that describes a task. Write a response that appropriate

## Finetuning

We provide an example of a finetuning configuration (`replicate_vicuna.yaml`). To enable the application of the LoRa method to the attention layers, the options of the checkpoint need to be overridden.


The finetuning can be launched with this command:

```shell
nohup python3 OpenNMT-py/onmt/bin/train.py -config replicate_vicuna.yaml > finetune-llama7B-vicuna-onmt.log &
```

We can start by generating some samples (by setting `dump_samples` to True and `n_samples` to a strictly positive value).
@@ -137,48 +139,60 @@ python3 OpenNMT-py/tools/lora_weights.py\
--output finetuned_llama7B/llama7B-vicuna-onmt_step_4000.concat.pt
```

### Conversion to the ctranslate2 format

To convert the concatenated checkpoint to the ctranslate2 format, run the following command:

```shell
python3 OpenNMT-py/onmt/bin/release_model.py \
--model finetuned_llama7B/llama7B-vicuna-onmt_step_4000.concat.pt \
--output finetuned_llama7B/llama7B-vicuna-onmt_step_4000.concat_CT2 \
--format ctranslate2 \
--quantization int8_float16
```
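
As a quick sanity check, the converted directory can also be queried directly with the ctranslate2 Python API. A minimal sketch; the prompt must already be tokenized into the model's SentencePiece pieces, and the tokens below are placeholders:

```python
# Sketch: generating from the converted checkpoint with ctranslate2 alone.
# The token list is a placeholder; real inputs come from the model's
# SentencePiece tokenizer.
import ctranslate2

generator = ctranslate2.Generator(
    "finetuned_llama7B/llama7B-vicuna-onmt_step_4000.concat_CT2",
    device="cuda",
)
results = generator.generate_batch(
    [["<s>", "▁Hello", ",", "▁how", "▁are", "▁you", "?"]],
    max_length=512,
    sampling_topk=40,
    sampling_temperature=0.1,
)
print(" ".join(results[0].sequences[0]))
```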

### Multi-round conversations with vicuna

We provide a gradio chatbot application that can be run with two different inference modes ("py" or ctranslate2).

Run one of the following commands:
```shell
python3 chatbot.py \
-inference_config_file translate_opts_py.yaml \
-inference_mode py \
-max_context_length 4096 \
-server_port 5000
```

Or:

```shell
python3 chatbot.py \
-inference_config_file translate_opts_ct2.yaml \
-inference_mode ct2 \
-max_context_length 4096 \
-server_port 5000
```
Where `translate_opts_ct2.yaml` and `translate_opts_py.yaml` are the provided configs with the translation options.
You can test other decoding methods and parameters.

### Simple inference

To obtain the model's inference you can run one of the following commands:


```shell
python3 simple_inference.py \
-input_file input_examples.txt \
-inference_config_file translate_opts_py.yaml \
-inference_mode py \
-output_dir outputs
```
Or:

```shell
python3 simple_inference.py \
-input_file input_examples.txt \
-inference_config_file translate_opts_ct2.yaml \
-inference_mode ct2 \
-output_dir outputs
```