Skip to content

Commit 7190b40

Browse files
committed
revise buffer management compliant with the virtio standard
1 parent d82fba8 commit 7190b40

File tree

4 files changed

+83
-49
lines changed

4 files changed

+83
-49
lines changed

src/config.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,6 @@ pub(crate) const VIRTIO_MAX_QUEUE_SIZE: u16 = 1024;
1414
/// Default keep alive interval in milliseconds
1515
#[cfg(feature = "tcp")]
1616
pub(crate) const DEFAULT_KEEP_ALIVE_INTERVAL: u64 = 75000;
17+
18+
#[cfg(feature = "vsock")]
19+
pub(crate) const VSOCK_PACKET_SIZE: u32 = 8192;

src/drivers/vsock/mod.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ impl RxQueue {
8181
vq: None,
8282
poll_sender,
8383
poll_receiver,
84-
packet_size: 8192u32 + mem::size_of::<Hdr>() as u32,
84+
packet_size: crate::VSOCK_PACKET_SIZE + mem::size_of::<Hdr>() as u32,
8585
}
8686
}
8787

@@ -163,7 +163,7 @@ impl TxQueue {
163163
pub fn new() -> Self {
164164
Self {
165165
vq: None,
166-
packet_length: 8192u32 + mem::size_of::<Hdr>() as u32,
166+
packet_length: crate::VSOCK_PACKET_SIZE + mem::size_of::<Hdr>() as u32,
167167
}
168168
}
169169

@@ -238,7 +238,7 @@ impl EventQueue {
238238
vq: None,
239239
poll_sender,
240240
poll_receiver,
241-
packet_size: 1024u32,
241+
packet_size: 128u32,
242242
}
243243
}
244244

src/executor/vsock.rs

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,13 @@ pub(crate) const RAW_SOCKET_BUFFER_SIZE: usize = 256 * 1024;
6767
pub(crate) struct RawSocket {
6868
pub remote_cid: u32,
6969
pub remote_port: u32,
70+
pub fwd_cnt: u32,
71+
pub peer_fwd_cnt: u32,
72+
pub peer_buf_alloc: u32,
73+
pub tx_cnt: u32,
7074
pub state: VsockState,
71-
pub waker: WakerRegistration,
75+
pub rx_waker: WakerRegistration,
76+
pub tx_waker: WakerRegistration,
7277
pub buffer: Vec<u8>,
7378
}
7479

@@ -77,8 +82,13 @@ impl RawSocket {
7782
Self {
7883
remote_cid: 0,
7984
remote_port: 0,
85+
fwd_cnt: 0,
86+
peer_fwd_cnt: 0,
87+
peer_buf_alloc: 0,
88+
tx_cnt: 0,
8089
state,
81-
waker: WakerRegistration::new(),
90+
rx_waker: WakerRegistration::new(),
91+
tx_waker: WakerRegistration::new(),
8292
buffer: Vec::with_capacity(RAW_SOCKET_BUFFER_SIZE),
8393
}
8494
}
@@ -90,7 +100,7 @@ async fn vsock_run() {
90100
const HEADER_SIZE: usize = core::mem::size_of::<Hdr>();
91101
let mut driver_guard = driver.lock();
92102
let mut hdr: Option<Hdr> = None;
93-
let mut fwd_cnt: Option<u32> = None;
103+
let mut fwd_cnt: u32 = 0;
94104

95105
driver_guard.process_packet(|header, data| {
96106
let op = Op::try_from(header.op.to_ne()).unwrap();
@@ -104,23 +114,28 @@ async fn vsock_run() {
104114
raw.state = VsockState::ReceiveRequest;
105115
raw.remote_cid = header.src_cid.to_ne().try_into().unwrap();
106116
raw.remote_port = header.src_port.to_ne();
107-
raw.waker.wake();
117+
raw.peer_buf_alloc = header.buf_alloc.to_ne();
118+
raw.rx_waker.wake();
108119
} else if (raw.state == VsockState::Connected
109120
|| raw.state == VsockState::Shutdown)
110121
&& type_ == Type::Stream
111122
&& op == Op::Rw
112123
{
113124
raw.buffer.extend_from_slice(data);
114-
raw.waker.wake();
125+
raw.fwd_cnt = raw.fwd_cnt.wrapping_add(u32::try_from(data.len()).unwrap());
126+
raw.peer_fwd_cnt = header.fwd_cnt.to_ne();
127+
raw.tx_waker.wake();
128+
raw.rx_waker.wake();
129+
hdr = Some(*header);
130+
fwd_cnt = raw.fwd_cnt;
115131
} else if op == Op::CreditUpdate {
116-
debug!("CrediteUpdate currently not supported: {:?}", header);
132+
raw.peer_fwd_cnt = header.fwd_cnt.to_ne();
133+
raw.tx_waker.wake();
117134
} else if op == Op::Shutdown {
118135
raw.state = VsockState::Shutdown;
119136
} else {
120137
hdr = Some(*header);
121-
if op == Op::CreditRequest {
122-
fwd_cnt = Some(raw.buffer.len().try_into().unwrap());
123-
}
138+
fwd_cnt = raw.fwd_cnt;
124139
}
125140
}
126141
});
@@ -135,17 +150,16 @@ async fn vsock_run() {
135150
response.dst_port = hdr.src_port;
136151
response.len = le32::from_ne(0);
137152
response.type_ = hdr.type_;
138-
if let Some(fwd_cnt) = fwd_cnt {
139-
// update fwd_cnt
153+
if hdr.op.to_ne() == Op::CreditRequest.into() || hdr.op.to_ne() == Op::Rw.into()
154+
{
140155
response.op = le16::from_ne(Op::CreditUpdate.into());
141-
response.fwd_cnt = le32::from_ne(fwd_cnt);
142156
} else {
143157
// reset connection
144158
response.op = le16::from_ne(Op::Rst.into());
145-
response.fwd_cnt = le32::from_ne(0);
146159
}
147160
response.flags = le32::from_ne(0);
148161
response.buf_alloc = le32::from_ne(RAW_SOCKET_BUFFER_SIZE as u32);
162+
response.fwd_cnt = le32::from_ne(fwd_cnt);
149163
});
150164
}
151165

src/fd/socket/vsock.rs

Lines changed: 50 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ impl ObjectInterface for Socket {
9595

9696
match raw.state {
9797
VsockState::Listen => {
98-
raw.waker.register(cx.waker());
98+
raw.rx_waker.register(cx.waker());
9999
Poll::Pending
100100
}
101101
VsockState::ReceiveRequest => {
@@ -123,7 +123,7 @@ impl ObjectInterface for Socket {
123123
response.buf_alloc = le32::from_ne(
124124
crate::executor::vsock::RAW_SOCKET_BUFFER_SIZE as u32,
125125
);
126-
response.fwd_cnt = le32::from_ne(0);
126+
response.fwd_cnt = le32::from_ne(raw.fwd_cnt);
127127
});
128128

129129
raw.state = VsockState::Connected;
@@ -174,7 +174,7 @@ impl ObjectInterface for Socket {
174174
let len = core::cmp::min(buffer.len(), raw.buffer.len());
175175

176176
if len == 0 {
177-
raw.waker.register(cx.waker());
177+
raw.rx_waker.register(cx.waker());
178178
Poll::Pending
179179
} else {
180180
let tmp: Vec<_> = raw.buffer.drain(..len).collect();
@@ -203,37 +203,54 @@ impl ObjectInterface for Socket {
203203

204204
async fn async_write(&self, buffer: &[u8]) -> io::Result<usize> {
205205
let port = self.port.load(Ordering::Acquire);
206-
let guard = VSOCK_MAP.lock();
207-
let raw = guard.get_socket(port).ok_or(Error::EINVAL)?;
208-
209-
match raw.state {
210-
VsockState::Connected => {
211-
const HEADER_SIZE: usize = core::mem::size_of::<Hdr>();
212-
let mut driver_guard = hardware::get_vsock_driver().unwrap().lock();
213-
let local_cid = driver_guard.get_cid();
214-
215-
driver_guard.send_packet(HEADER_SIZE + buffer.len(), |virtio_buffer| {
216-
let response = unsafe { &mut *(virtio_buffer.as_mut_ptr() as *mut Hdr) };
217-
218-
response.src_cid = le64::from_ne(local_cid);
219-
response.dst_cid = le64::from_ne(raw.remote_cid as u64);
220-
response.src_port = le32::from_ne(port);
221-
response.dst_port = le32::from_ne(raw.remote_port);
222-
response.len = le32::from_ne(buffer.len().try_into().unwrap());
223-
response.type_ = le16::from_ne(Type::Stream.into());
224-
response.op = le16::from_ne(Op::Rw.into());
225-
response.flags = le32::from_ne(0);
226-
response.buf_alloc =
227-
le32::from_ne(crate::executor::vsock::RAW_SOCKET_BUFFER_SIZE as u32);
228-
response.fwd_cnt = le32::from_ne(raw.buffer.len().try_into().unwrap());
229-
230-
virtio_buffer[HEADER_SIZE..].copy_from_slice(buffer);
231-
});
232-
233-
Ok(buffer.len())
206+
future::poll_fn(|cx| {
207+
let mut guard = VSOCK_MAP.lock();
208+
let raw = guard.get_mut_socket(port).ok_or(Error::EINVAL)?;
209+
let diff = raw.tx_cnt.abs_diff(raw.peer_fwd_cnt);
210+
211+
match raw.state {
212+
VsockState::Connected => {
213+
if diff >= raw.peer_buf_alloc {
214+
raw.tx_waker.register(cx.waker());
215+
Poll::Pending
216+
} else {
217+
const HEADER_SIZE: usize = core::mem::size_of::<Hdr>();
218+
let mut driver_guard = hardware::get_vsock_driver().unwrap().lock();
219+
let local_cid = driver_guard.get_cid();
220+
let len = core::cmp::min(
221+
buffer.len(),
222+
usize::try_from(raw.peer_buf_alloc - diff).unwrap(),
223+
);
224+
225+
driver_guard.send_packet(HEADER_SIZE + len, |virtio_buffer| {
226+
let response =
227+
unsafe { &mut *(virtio_buffer.as_mut_ptr() as *mut Hdr) };
228+
229+
raw.tx_cnt = raw.tx_cnt.wrapping_add(len.try_into().unwrap());
230+
response.src_cid = le64::from_ne(local_cid);
231+
response.dst_cid = le64::from_ne(raw.remote_cid as u64);
232+
response.src_port = le32::from_ne(port);
233+
response.dst_port = le32::from_ne(raw.remote_port);
234+
response.len = le32::from_ne(len.try_into().unwrap());
235+
response.type_ = le16::from_ne(Type::Stream.into());
236+
response.op = le16::from_ne(Op::Rw.into());
237+
response.flags = le32::from_ne(0);
238+
response.buf_alloc = le32::from_ne(
239+
crate::executor::vsock::RAW_SOCKET_BUFFER_SIZE as u32,
240+
);
241+
response.fwd_cnt = le32::from_ne(raw.fwd_cnt);
242+
243+
virtio_buffer[HEADER_SIZE..HEADER_SIZE + len]
244+
.copy_from_slice(&buffer[..len]);
245+
});
246+
247+
Poll::Ready(Ok(len))
248+
}
249+
}
250+
_ => Poll::Ready(Err(Error::EIO)),
234251
}
235-
_ => Err(Error::EIO),
236-
}
252+
})
253+
.await
237254
}
238255
}
239256

0 commit comments

Comments
 (0)