@@ -95,7 +95,7 @@ impl ObjectInterface for Socket {
95
95
96
96
match raw. state {
97
97
VsockState :: Listen => {
98
- raw. waker . register ( cx. waker ( ) ) ;
98
+ raw. rx_waker . register ( cx. waker ( ) ) ;
99
99
Poll :: Pending
100
100
}
101
101
VsockState :: ReceiveRequest => {
@@ -123,7 +123,7 @@ impl ObjectInterface for Socket {
123
123
response. buf_alloc = le32:: from_ne (
124
124
crate :: executor:: vsock:: RAW_SOCKET_BUFFER_SIZE as u32 ,
125
125
) ;
126
- response. fwd_cnt = le32:: from_ne ( 0 ) ;
126
+ response. fwd_cnt = le32:: from_ne ( raw . fwd_cnt ) ;
127
127
} ) ;
128
128
129
129
raw. state = VsockState :: Connected ;
@@ -174,7 +174,7 @@ impl ObjectInterface for Socket {
174
174
let len = core:: cmp:: min ( buffer. len ( ) , raw. buffer . len ( ) ) ;
175
175
176
176
if len == 0 {
177
- raw. waker . register ( cx. waker ( ) ) ;
177
+ raw. rx_waker . register ( cx. waker ( ) ) ;
178
178
Poll :: Pending
179
179
} else {
180
180
let tmp: Vec < _ > = raw. buffer . drain ( ..len) . collect ( ) ;
@@ -203,37 +203,54 @@ impl ObjectInterface for Socket {
203
203
204
204
async fn async_write ( & self , buffer : & [ u8 ] ) -> io:: Result < usize > {
205
205
let port = self . port . load ( Ordering :: Acquire ) ;
206
- let guard = VSOCK_MAP . lock ( ) ;
207
- let raw = guard. get_socket ( port) . ok_or ( Error :: EINVAL ) ?;
208
-
209
- match raw. state {
210
- VsockState :: Connected => {
211
- const HEADER_SIZE : usize = core:: mem:: size_of :: < Hdr > ( ) ;
212
- let mut driver_guard = hardware:: get_vsock_driver ( ) . unwrap ( ) . lock ( ) ;
213
- let local_cid = driver_guard. get_cid ( ) ;
214
-
215
- driver_guard. send_packet ( HEADER_SIZE + buffer. len ( ) , |virtio_buffer| {
216
- let response = unsafe { & mut * ( virtio_buffer. as_mut_ptr ( ) as * mut Hdr ) } ;
217
-
218
- response. src_cid = le64:: from_ne ( local_cid) ;
219
- response. dst_cid = le64:: from_ne ( raw. remote_cid as u64 ) ;
220
- response. src_port = le32:: from_ne ( port) ;
221
- response. dst_port = le32:: from_ne ( raw. remote_port ) ;
222
- response. len = le32:: from_ne ( buffer. len ( ) . try_into ( ) . unwrap ( ) ) ;
223
- response. type_ = le16:: from_ne ( Type :: Stream . into ( ) ) ;
224
- response. op = le16:: from_ne ( Op :: Rw . into ( ) ) ;
225
- response. flags = le32:: from_ne ( 0 ) ;
226
- response. buf_alloc =
227
- le32:: from_ne ( crate :: executor:: vsock:: RAW_SOCKET_BUFFER_SIZE as u32 ) ;
228
- response. fwd_cnt = le32:: from_ne ( raw. buffer . len ( ) . try_into ( ) . unwrap ( ) ) ;
229
-
230
- virtio_buffer[ HEADER_SIZE ..] . copy_from_slice ( buffer) ;
231
- } ) ;
232
-
233
- Ok ( buffer. len ( ) )
206
+ future:: poll_fn ( |cx| {
207
+ let mut guard = VSOCK_MAP . lock ( ) ;
208
+ let raw = guard. get_mut_socket ( port) . ok_or ( Error :: EINVAL ) ?;
209
+ let diff = raw. tx_cnt . abs_diff ( raw. peer_fwd_cnt ) ;
210
+
211
+ match raw. state {
212
+ VsockState :: Connected => {
213
+ if diff >= raw. peer_buf_alloc {
214
+ raw. tx_waker . register ( cx. waker ( ) ) ;
215
+ Poll :: Pending
216
+ } else {
217
+ const HEADER_SIZE : usize = core:: mem:: size_of :: < Hdr > ( ) ;
218
+ let mut driver_guard = hardware:: get_vsock_driver ( ) . unwrap ( ) . lock ( ) ;
219
+ let local_cid = driver_guard. get_cid ( ) ;
220
+ let len = core:: cmp:: min (
221
+ buffer. len ( ) ,
222
+ usize:: try_from ( raw. peer_buf_alloc - diff) . unwrap ( ) ,
223
+ ) ;
224
+
225
+ driver_guard. send_packet ( HEADER_SIZE + len, |virtio_buffer| {
226
+ let response =
227
+ unsafe { & mut * ( virtio_buffer. as_mut_ptr ( ) as * mut Hdr ) } ;
228
+
229
+ raw. tx_cnt = raw. tx_cnt . wrapping_add ( len. try_into ( ) . unwrap ( ) ) ;
230
+ response. src_cid = le64:: from_ne ( local_cid) ;
231
+ response. dst_cid = le64:: from_ne ( raw. remote_cid as u64 ) ;
232
+ response. src_port = le32:: from_ne ( port) ;
233
+ response. dst_port = le32:: from_ne ( raw. remote_port ) ;
234
+ response. len = le32:: from_ne ( len. try_into ( ) . unwrap ( ) ) ;
235
+ response. type_ = le16:: from_ne ( Type :: Stream . into ( ) ) ;
236
+ response. op = le16:: from_ne ( Op :: Rw . into ( ) ) ;
237
+ response. flags = le32:: from_ne ( 0 ) ;
238
+ response. buf_alloc = le32:: from_ne (
239
+ crate :: executor:: vsock:: RAW_SOCKET_BUFFER_SIZE as u32 ,
240
+ ) ;
241
+ response. fwd_cnt = le32:: from_ne ( raw. fwd_cnt ) ;
242
+
243
+ virtio_buffer[ HEADER_SIZE ..HEADER_SIZE + len]
244
+ . copy_from_slice ( & buffer[ ..len] ) ;
245
+ } ) ;
246
+
247
+ Poll :: Ready ( Ok ( len) )
248
+ }
249
+ }
250
+ _ => Poll :: Ready ( Err ( Error :: EIO ) ) ,
234
251
}
235
- _ => Err ( Error :: EIO ) ,
236
- }
252
+ } )
253
+ . await
237
254
}
238
255
}
239
256
0 commit comments