Skip to content

Commit 96b37dc

Browse files
authored
[Feature]: Add Cohere2 (#1065)
* modeling_cohere2 * edits * update generate * fix errors * cohere2 test script * revise scalar dtype * fix graph errors * readme & pre-commit fix
1 parent ec34b6e commit 96b37dc

File tree

6 files changed

+1332
-0
lines changed

6 files changed

+1332
-0
lines changed
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
2+
# Cohere2
3+
4+
## Overview
5+
6+
[C4AI Command R7B](https://cohere.com/blog/command-r7b) is an open weights research release of a 7B billion parameter model developed by Cohere and Cohere For AI. It has advanced capabilities optimized for various use cases, including reasoning, summarization, question answering, and code. The model is trained to perform sophisticated tasks including Retrieval Augmented Generation (RAG) and tool use. The model also has powerful agentic capabilities that can use and combine multiple tools over multiple steps to accomplish more difficult tasks. It obtains top performance on enterprise-relevant code use cases. C4AI Command R7B is a multilingual model trained on 23 languages.
7+
8+
The model features three layers with sliding window attention (window size 4096) and ROPE for efficient local context modeling and relative positional encoding. A fourth layer uses global attention without positional embeddings, enabling unrestricted token interactions across the entire sequence.
9+
10+
The model has been trained on 23 languages: English, French, Spanish, Italian, German, Portuguese, Japanese, Korean, Arabic, Chinese, Russian, Polish, Turkish, Vietnamese, Dutch, Czech, Indonesian, Ukrainian, Romanian, Greek, Hindi, Hebrew, and Persian.
11+
12+
13+
14+
## Checkpoints
15+
16+
You can download the checkpoints using the following command:
17+
```bash
18+
huggingface-cli download --resume-download CohereForAI/c4ai-command-r7b-12-2024
19+
```
20+
21+
## Examples
22+
23+
Here's an example usage:
24+
25+
```python
26+
from time import time
27+
28+
from transformers import AutoTokenizer
29+
30+
import mindspore as ms
31+
from mindspore import Tensor
32+
33+
from mindone.transformers import Cohere2ForCausalLM
34+
35+
model_id = "CohereForAI/c4ai-command-r7b-12-2024"
36+
tokenizer = AutoTokenizer.from_pretrained(model_id)
37+
model = Cohere2ForCausalLM.from_pretrained(model_id, mindspore_dtype=ms.float16)
38+
39+
messages = [{"role": "user", "content": "How do plants make energy?"}]
40+
input_ids = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="np")
41+
42+
input_ids = (
43+
Tensor(input_ids) if (len(input_ids.shape) == 2 and input_ids.shape[0] == 1) else Tensor(input_ids).unsqueeze(0)
44+
) # (1, L)
45+
infer_start = time()
46+
output = model.generate(
47+
input_ids,
48+
max_new_tokens=100,
49+
do_sample=True,
50+
temperature=0.3,
51+
cache_implementation="static",
52+
)
53+
print(f"Inference time: {time() - infer_start:.3f}s")
54+
print(tokenizer.decode(output[0], skip_special_tokens=True))
55+
```
56+
57+
See `./generate.py` for detailed usage.
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
from time import time
2+
3+
from transformers import AutoTokenizer
4+
5+
import mindspore as ms
6+
from mindspore import Tensor
7+
8+
from mindone.transformers import Cohere2ForCausalLM
9+
10+
ms.set_context(mode=ms.PYNATIVE_MODE)
11+
12+
13+
def main():
14+
model_id = "CohereForAI/c4ai-command-r7b-12-2024"
15+
tokenizer = AutoTokenizer.from_pretrained(model_id)
16+
model = Cohere2ForCausalLM.from_pretrained(model_id, mindspore_dtype=ms.float16)
17+
18+
messages = [{"role": "user", "content": "How do plants make energy?"}]
19+
input_ids = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="np")
20+
21+
input_ids = (
22+
Tensor(input_ids) if (len(input_ids.shape) == 2 and input_ids.shape[0] == 1) else Tensor(input_ids).unsqueeze(0)
23+
) # (1, L)
24+
infer_start = time()
25+
output = model.generate(
26+
input_ids,
27+
max_new_tokens=100,
28+
do_sample=True,
29+
temperature=0.3,
30+
cache_implementation="static",
31+
)
32+
print(f"Inference time: {time() - infer_start:.3f}s")
33+
print(tokenizer.decode(output[0], skip_special_tokens=True))
34+
35+
36+
if __name__ == "__main__":
37+
main()

mindone/transformers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@
9999
CLIPVisionModel,
100100
CLIPVisionModelWithProjection,
101101
)
102+
from .models.cohere2 import Cohere2ForCausalLM, Cohere2Model, Cohere2PreTrainedModel
102103
from .models.deberta import (
103104
DebertaForMaskedLM,
104105
DebertaForQuestionAnswering,
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from transformers import Cohere2Config
2+
3+
from .modeling_cohere2 import Cohere2ForCausalLM, Cohere2Model, Cohere2PreTrainedModel

0 commit comments

Comments
 (0)