| 
3 | 3 | namespace torch_ext {  | 
4 | 4 | 
 
  | 
5 | 5 | void registerPyOpDefs(pybind11::module& m) {  | 
 | 6 | +    pybind11::class_<MlaParams>(m, "MlaParams")  | 
 | 7 | +        .def(pybind11::init<>())  | 
 | 8 | +        .def_readonly("batch_indice", &MlaParams::batch_indice)  | 
 | 9 | +        .def_readonly("positions", &MlaParams::positions)  | 
 | 10 | +        .def_readonly("paged_kv_last_page_len", &MlaParams::paged_kv_last_page_len)  | 
 | 11 | +        .def_readonly("kvlen", &MlaParams::kvlen)  | 
 | 12 | +        .def_readonly("page_indice", &MlaParams::page_indice)  | 
 | 13 | +        .def_readonly("page_indptr", &MlaParams::page_indptr)  | 
 | 14 | +        .def_readonly("qo_indptr", &MlaParams::qo_indptr);  | 
 | 15 | + | 
6 | 16 |     pybind11::class_<KVCache>(m, "KVCache")  | 
7 | 17 |         .def(pybind11::init<>())  | 
8 |  | -        .def_readonly("k_cache_base", &KVCache::k_cache_base, "Key cache base tensor")  | 
 | 18 | +        .def_readwrite("k_cache_base", &KVCache::k_cache_base, "Key cache base tensor")  | 
9 | 19 |         .def_readonly("v_cache_base", &KVCache::v_cache_base, "Value cache base tensor")  | 
10 | 20 |         .def_readonly("k_scale_base", &KVCache::k_scale_base, "Key cache scale tensor")  | 
11 | 21 |         .def_readonly("v_scale_base", &KVCache::v_scale_base, "Value cache scale tensor")  | 
@@ -43,12 +53,12 @@ void registerPyOpDefs(pybind11::module& m) {  | 
43 | 53 | 
 
  | 
44 | 54 |     pybind11::class_<PyAttentionInputs>(m, "PyAttentionInputs")  | 
45 | 55 |         .def(pybind11::init<>())  | 
46 |  | -        .def_readonly("is_prefill", &PyAttentionInputs::is_prefill)  | 
47 |  | -        .def_readonly("prefix_lengths", &PyAttentionInputs::prefix_lengths)  | 
48 |  | -        .def_readonly("sequence_lengths", &PyAttentionInputs::sequence_lengths)  | 
49 |  | -        .def_readonly("input_lengths", &PyAttentionInputs::input_lengths)  | 
 | 56 | +        .def_readwrite("is_prefill", &PyAttentionInputs::is_prefill)  | 
 | 57 | +        .def_readwrite("prefix_lengths", &PyAttentionInputs::prefix_lengths)  | 
 | 58 | +        .def_readwrite("sequence_lengths", &PyAttentionInputs::sequence_lengths)  | 
 | 59 | +        .def_readwrite("input_lengths", &PyAttentionInputs::input_lengths)  | 
50 | 60 |         .def_readonly("cu_seqlens", &PyAttentionInputs::cu_seqlens)  | 
51 |  | -        .def_readonly("kv_cache_block_id_host", &PyAttentionInputs::kv_cache_block_id_host)  | 
 | 61 | +        .def_readwrite("kv_cache_block_id_host", &PyAttentionInputs::kv_cache_block_id_host)  | 
52 | 62 |         .def_readonly("kv_cache_block_id_device", &PyAttentionInputs::kv_cache_block_id_device)  | 
53 | 63 |         .def_readonly("dtype", &PyAttentionInputs::dtype)  | 
54 | 64 |         .def_readonly("kv_block_offset", &PyAttentionInputs::kv_block_offset)  | 
 | 
0 commit comments