-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathqwen_pybind.cpp
82 lines (70 loc) · 3.82 KB
/
qwen_pybind.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
#include "qwen.h"
#include "tokenizer/tiktoken.h"
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
namespace qwen {
namespace py = pybind11;
using namespace pybind11::literals;
/// Renders `obj` as text by inserting it into a string stream with `operator<<`
/// and returning the accumulated characters.
///
/// Used below as the implementation of Python `__repr__` / `__str__`, so `T`
/// must have a stream-insertion operator visible at the point of instantiation.
template <typename T>
static inline std::string to_string(const T &obj) {
    std::ostringstream buffer;
    buffer << obj;
    return buffer.str();
}
// Defines the `_C` extension module. Each section below registers one C++ type
// with pybind11 under the quoted Python name, then attaches its attributes and
// methods one statement at a time, in the order they appear.
PYBIND11_MODULE(_C, m) {
    m.doc() = "qwen.cpp python binding";

    // QwenConfig: every field is exposed as a read-only attribute; no constructor is bound.
    py::class_<QwenConfig> config_cls(m, "QwenConfig");
    config_cls.def_readonly("dtype", &QwenConfig::dtype);
    config_cls.def_readonly("vocab_size", &QwenConfig::vocab_size);
    config_cls.def_readonly("hidden_size", &QwenConfig::hidden_size);
    config_cls.def_readonly("num_attention_heads", &QwenConfig::num_attention_heads);
    config_cls.def_readonly("num_kv_heads", &QwenConfig::num_kv_heads);
    config_cls.def_readonly("num_hidden_layers", &QwenConfig::num_hidden_layers);
    config_cls.def_readonly("intermediate_size", &QwenConfig::intermediate_size);
    config_cls.def_readonly("max_length", &QwenConfig::max_length);
    config_cls.def_readonly("eos_token_id", &QwenConfig::eos_token_id);
    config_cls.def_readonly("pad_token_id", &QwenConfig::pad_token_id);
    config_cls.def_readonly("im_start_id", &QwenConfig::im_start_id);
    config_cls.def_readonly("im_end_id", &QwenConfig::im_end_id);

    // tiktoken_cpp: constructible from Python with two string->int maps and a string.
    using TiktokenImpl = tiktoken::tiktoken;
    using StringIntMap = std::unordered_map<std::string, int>;
    py::class_<TiktokenImpl> tiktoken_cls(m, "tiktoken_cpp");
    tiktoken_cls.def(py::init<StringIntMap, StringIntMap, const std::string &>());
    tiktoken_cls.def("encode_ordinary", &TiktokenImpl::encode_ordinary);
    tiktoken_cls.def("encode", &TiktokenImpl::encode);
    tiktoken_cls.def("encode_single_piece", &TiktokenImpl::encode_single_piece);
    tiktoken_cls.def("decode", &TiktokenImpl::decode);

    // QwenForCausalLM: no constructor is bound here; instances reach Python via Pipeline.model.
    py::class_<QwenForCausalLM> model_cls(m, "QwenForCausalLM");
    model_cls.def_readonly("config", &QwenForCausalLM::config);
    model_cls.def("generate_next_token", &QwenForCausalLM::generate_next_token);

    // QwenTokenizer: no constructor is bound here; instances reach Python via Pipeline.tokenizer.
    py::class_<QwenTokenizer> tokenizer_cls(m, "QwenTokenizer");
    tokenizer_cls.def("encode", &QwenTokenizer::encode);
    tokenizer_cls.def("decode", &QwenTokenizer::decode);
    tokenizer_cls.def("encode_messages", &QwenTokenizer::encode_messages);
    tokenizer_cls.def("decode_message", &QwenTokenizer::decode_message);

    // GenerationConfig: keyword-constructible with a default for every argument,
    // and every field stays writable after construction.
    py::class_<GenerationConfig> generation_cls(m, "GenerationConfig");
    generation_cls.def(py::init<int, int, bool, int, float, float, float, int>(),
                       "max_length"_a = 2048,
                       "max_context_length"_a = 512,
                       "do_sample"_a = true,
                       "top_k"_a = 0,
                       "top_p"_a = 0.7,
                       "temperature"_a = 0.95,
                       "repetition_penalty"_a = 1.0,
                       "num_threads"_a = 0);
    generation_cls.def_readwrite("max_length", &GenerationConfig::max_length);
    generation_cls.def_readwrite("max_context_length", &GenerationConfig::max_context_length);
    generation_cls.def_readwrite("do_sample", &GenerationConfig::do_sample);
    generation_cls.def_readwrite("top_k", &GenerationConfig::top_k);
    generation_cls.def_readwrite("top_p", &GenerationConfig::top_p);
    generation_cls.def_readwrite("temperature", &GenerationConfig::temperature);
    generation_cls.def_readwrite("repetition_penalty", &GenerationConfig::repetition_penalty);
    generation_cls.def_readwrite("num_threads", &GenerationConfig::num_threads);

    // ChatMessage: built from (role, content); both string conversions go through
    // the stream-based to_string helper, so repr() and str() yield the same text.
    py::class_<ChatMessage> message_cls(m, "ChatMessage");
    message_cls.def(py::init<std::string, std::string>(), "role"_a, "content"_a);
    message_cls.def("__repr__", &to_string<ChatMessage>);
    message_cls.def("__str__", &to_string<ChatMessage>);
    message_cls.def_readonly_static("ROLE_SYSTEM", &ChatMessage::ROLE_SYSTEM);
    message_cls.def_readonly_static("ROLE_USER", &ChatMessage::ROLE_USER);
    message_cls.def_readonly_static("ROLE_ASSISTANT", &ChatMessage::ROLE_ASSISTANT);
    message_cls.def_readwrite("role", &ChatMessage::role);
    message_cls.def_readwrite("content", &ChatMessage::content);

    // Pipeline: constructed from two strings and an int. The getters hand out the
    // raw pointers held by the pipeline's `model` / `tokenizer` members; pybind11
    // applies its default property-getter return policy to them.
    py::class_<Pipeline> pipeline_cls(m, "Pipeline");
    pipeline_cls.def(py::init<const std::string &, const std::string &, const int>());
    pipeline_cls.def_property_readonly("model", [](const Pipeline &pipeline) { return pipeline.model.get(); });
    pipeline_cls.def_property_readonly("tokenizer",
                                       [](const Pipeline &pipeline) { return pipeline.tokenizer.get(); });
}
} // namespace qwen