-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpackets.py
138 lines (108 loc) · 3.74 KB
/
packets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import struct
from enum import IntEnum, unique
from typing import Any
import helpers
# binary deserialization (reading)
@unique
class ResponseType(IntEnum):
    """Type tag of each backend message this client may receive.

    Each member's value is the ASCII code of the single tag byte that starts
    the message on the wire. ``ResponseType(tag_byte)`` therefore maps a raw
    header byte to its message type and raises ``ValueError`` for an unknown
    tag.
    """

    # https://www.postgresql.org/docs/14/protocol-message-formats.html
    AuthenticationRequest = ord("R")
    BackendKeyData = ord("K")
    BindComplete = ord("2")
    # previously missing: without it, a reply to a Close message could not
    # be decoded by ResponseType(...)
    CloseComplete = ord("3")
    CommandComplete = ord("C")
    CopyData = ord("d")
    CopyDone = ord("c")
    CopyInResponse = ord("G")
    CopyOutResponse = ord("H")
    CopyBothResponse = ord("W")
    DataRow = ord("D")
    EmptyQueryResponse = ord("I")
    ErrorResponse = ord("E")
    FunctionCallResponse = ord("V")
    NegotiateProtocolVersion = ord("v")
    NoData = ord("n")
    NoticeResponse = ord("N")
    NotificationResponse = ord("A")
    ParameterDescription = ord("t")
    ParameterStatus = ord("S")
    ParseComplete = ord("1")
    PortalSuspended = ord("s")
    ReadyForQuery = ord("Z")
    RowDescription = ord("T")
def read_header(data: bytes) -> tuple[ResponseType, int]:
    """Split a 5-byte backend message header into its type and length.

    The header is a one-byte type tag followed by a big-endian signed
    32-bit length.

    Args:
        data: exactly five header bytes.

    Returns:
        ``(message_type, message_length)``.

    Raises:
        AssertionError: if ``data`` is not exactly five bytes long.
        ValueError: if the tag byte is not a known ``ResponseType``.
    """
    assert len(data) == 5
    kind = ResponseType(data[0])
    # the length field occupies bytes 1..4
    (length,) = struct.unpack_from(">i", data, 1)
    return kind, length
class PacketReader:
    """Sequential big-endian reader over a message payload.

    Every ``read_*`` method consumes bytes from the front of ``data_view``
    by rebinding it to the remaining slice. Slicing a ``memoryview`` does
    not copy the underlying buffer.
    """

    def __init__(self, data_view: memoryview) -> None:
        # the unread remainder of the payload; shrinks as values are read
        self.data_view = data_view

    def read(self, fmt: str) -> tuple[Any, ...]:
        """Unpack ``fmt`` from the front of the buffer and consume it.

        ``fmt`` is a :mod:`struct` format string. Include a byte-order
        prefix (e.g. ``">"``) for network byte order.

        Raises:
            struct.error: if fewer than ``struct.calcsize(fmt)`` bytes remain.
        """
        size = struct.calcsize(fmt)
        # unpack from the *start* of the remaining data, then drop exactly
        # the bytes that were unpacked (the slices were previously inverted)
        vals = struct.unpack_from(fmt, self.data_view)
        self.data_view = self.data_view[size:]
        return vals

    def read_bytes(self, count: int) -> bytes:
        """Consume and return the next ``count`` bytes (fewer if short)."""
        val = self.data_view[:count].tobytes()
        self.data_view = self.data_view[count:]
        return val

    def read_u8(self) -> int:
        """Consume and return one unsigned byte."""
        val = self.data_view[0]
        self.data_view = self.data_view[1:]
        return val

    def read_i16(self) -> int:
        """Consume and return a big-endian signed 16-bit integer."""
        (val,) = struct.unpack(">h", self.data_view[:2])
        self.data_view = self.data_view[2:]
        return val

    def read_i32(self) -> int:
        """Consume and return a big-endian signed 32-bit integer."""
        (val,) = struct.unpack(">i", self.data_view[:4])
        self.data_view = self.data_view[4:]
        return val

    def read_variadic_string(self) -> str:
        """Consume an Int32 byte length followed by that many bytes of text."""
        length = self.read_i32()
        # NOTE(review): a length of -1 (how the PostgreSQL protocol marks a
        # NULL column value in DataRow) is not handled here and would
        # mis-slice the view; confirm callers never hit that case.
        val = self.data_view[:length].tobytes().decode()
        self.data_view = self.data_view[length:]
        return val

    def read_nullterm_string(self) -> str:
        """Consume a NUL-terminated string, including its terminator.

        Raises:
            ValueError: if the remaining data contains no NUL byte.
        """
        # TODO: use a better method than bytes.find to avoid copy
        remainder = self.data_view.tobytes()
        length = remainder.find(b"\x00")
        if length < 0:
            # find()'s -1 would otherwise silently drop the last byte and
            # leave the view unconsumed
            raise ValueError("unterminated string in packet data")
        val = remainder[:length].decode()
        self.data_view = self.data_view[length + 1 :]
        return val
# binary serialization (writing)
# TODO: some sort of ordering of these packets
def startup(
    proto_ver_major: int,
    proto_ver_minor: int,
    db_params: dict[bytes, bytes],
) -> bytes:
    """Build a StartupMessage, which carries no leading type byte.

    Layout: Int32 total length (counting itself), Int16 major and Int16
    minor protocol version, then NUL-terminated name/value pairs, then one
    extra NUL.

    Args:
        proto_ver_major: protocol major version.
        proto_ver_minor: protocol minor version.
        db_params: startup parameters as raw bytes, e.g. ``{b"user": b"..."}``.

    Returns:
        The complete message as a ``bytearray``.
    """
    body = bytearray(struct.pack(">hh", proto_ver_major, proto_ver_minor))
    for name, value in db_params.items():
        body += name + b"\x00" + value + b"\x00"
    # the parameter list ends with an empty name
    body += b"\x00"
    header = bytearray(struct.pack(">i", len(body) + 4))
    return header + body
def termination() -> bytes:
    """Build a Terminate ('X') message, which has an empty body."""
    # the length field counts only its own 4 bytes
    return bytearray(b"X" + struct.pack(">i", 4))
def query(query: str) -> bytes:
    """Build a simple-query ('Q') message carrying ``query``.

    Args:
        query: SQL text, encoded with ``str.encode()`` (UTF-8).

    Returns:
        The framed message as a ``bytearray``: type byte, big-endian Int32
        length (counting itself), then the encoded query and a NUL
        terminator.
    """
    # the length must count encoded bytes, not characters; otherwise any
    # non-ASCII query is framed too short
    encoded = query.encode()
    packet = bytearray(b"Q")
    packet += struct.pack(">i", 4 + len(encoded) + 1)
    packet += encoded + b"\x00"
    return packet
def auth_md5_pass(db_user: bytes, db_pass: bytes, salt: bytes) -> bytes:
    """Build a PasswordMessage ('p') that answers an MD5 auth challenge.

    The password field is ``b"md5"`` followed by
    ``md5hex(md5hex(db_pass + db_user) + salt)``, NUL-terminated.

    Args:
        db_user: user name bytes.
        db_pass: password bytes.
        salt: salt bytes supplied by the server.
    """
    inner_digest = helpers.md5hex(db_pass + db_user)
    outer_digest = helpers.md5hex(inner_digest + salt)
    # 4 (length field) + 3 ("md5") + 32 (hex digest) + 1 (NUL); assumes
    # helpers.md5hex yields a 32-byte hex digest
    packet = bytearray(b"p")
    packet += struct.pack(">i", 4 + 3 + 32 + 1)
    packet += b"md5"
    packet += outer_digest
    packet += b"\x00"
    return packet