|
3 | 3 | namespace torch_ext { |
4 | 4 |
|
5 | 5 | void registerPyOpDefs(pybind11::module& m) { |
| 6 | + pybind11::class_<MlaParams>(m, "MlaParams") |
| 7 | + .def(pybind11::init<>()) |
| 8 | + .def_readonly("batch_indice", &MlaParams::batch_indice) |
| 9 | + .def_readonly("positions", &MlaParams::positions) |
| 10 | + .def_readonly("paged_kv_last_page_len", &MlaParams::paged_kv_last_page_len) |
| 11 | + .def_readonly("kvlen", &MlaParams::kvlen) |
| 12 | + .def_readonly("page_indice", &MlaParams::page_indice) |
| 13 | + .def_readonly("page_indptr", &MlaParams::page_indptr) |
| 14 | + .def_readonly("qo_indptr", &MlaParams::qo_indptr); |
| 15 | + |
6 | 16 | pybind11::class_<KVCache>(m, "KVCache") |
7 | 17 | .def(pybind11::init<>()) |
8 | | - .def_readonly("k_cache_base", &KVCache::k_cache_base, "Key cache base tensor") |
| 18 | + .def_readwrite("k_cache_base", &KVCache::k_cache_base, "Key cache base tensor") |
9 | 19 | .def_readonly("v_cache_base", &KVCache::v_cache_base, "Value cache base tensor") |
10 | 20 | .def_readonly("k_scale_base", &KVCache::k_scale_base, "Key cache scale tensor") |
11 | 21 | .def_readonly("v_scale_base", &KVCache::v_scale_base, "Value cache scale tensor") |
@@ -43,12 +53,12 @@ void registerPyOpDefs(pybind11::module& m) { |
43 | 53 |
|
44 | 54 | pybind11::class_<PyAttentionInputs>(m, "PyAttentionInputs") |
45 | 55 | .def(pybind11::init<>()) |
46 | | - .def_readonly("is_prefill", &PyAttentionInputs::is_prefill) |
47 | | - .def_readonly("prefix_lengths", &PyAttentionInputs::prefix_lengths) |
48 | | - .def_readonly("sequence_lengths", &PyAttentionInputs::sequence_lengths) |
49 | | - .def_readonly("input_lengths", &PyAttentionInputs::input_lengths) |
| 56 | + .def_readwrite("is_prefill", &PyAttentionInputs::is_prefill) |
| 57 | + .def_readwrite("prefix_lengths", &PyAttentionInputs::prefix_lengths) |
| 58 | + .def_readwrite("sequence_lengths", &PyAttentionInputs::sequence_lengths) |
| 59 | + .def_readwrite("input_lengths", &PyAttentionInputs::input_lengths) |
50 | 60 | .def_readonly("cu_seqlens", &PyAttentionInputs::cu_seqlens) |
51 | | - .def_readonly("kv_cache_block_id_host", &PyAttentionInputs::kv_cache_block_id_host) |
| 61 | + .def_readwrite("kv_cache_block_id_host", &PyAttentionInputs::kv_cache_block_id_host) |
52 | 62 | .def_readonly("kv_cache_block_id_device", &PyAttentionInputs::kv_cache_block_id_device) |
53 | 63 | .def_readonly("dtype", &PyAttentionInputs::dtype) |
54 | 64 | .def_readonly("kv_block_offset", &PyAttentionInputs::kv_block_offset) |
|
0 commit comments