@@ -107,12 +107,13 @@ async fn vsock_run() {
107
107
let port = header. dst_port . to_ne ( ) ;
108
108
let type_ = Type :: try_from ( header. type_ . to_ne ( ) ) . unwrap ( ) ;
109
109
let mut vsock_guard = VSOCK_MAP . lock ( ) ;
110
+ let header_cid: u32 = header. src_cid . to_ne ( ) . try_into ( ) . unwrap ( ) ;
110
111
111
112
if let Some ( raw) = vsock_guard. get_mut_socket ( port) {
112
113
if op == Op :: Request && raw. state == VsockState :: Listen && type_ == Type :: Stream
113
114
{
114
115
raw. state = VsockState :: ReceiveRequest ;
115
- raw. remote_cid = header . src_cid . to_ne ( ) . try_into ( ) . unwrap ( ) ;
116
+ raw. remote_cid = header_cid ;
116
117
raw. remote_port = header. src_port . to_ne ( ) ;
117
118
raw. peer_buf_alloc = header. buf_alloc . to_ne ( ) ;
118
119
raw. rx_waker . wake ( ) ;
@@ -121,21 +122,38 @@ async fn vsock_run() {
121
122
&& type_ == Type :: Stream
122
123
&& op == Op :: Rw
123
124
{
124
- raw. buffer . extend_from_slice ( data) ;
125
- raw. fwd_cnt = raw. fwd_cnt . wrapping_add ( u32:: try_from ( data. len ( ) ) . unwrap ( ) ) ;
126
- raw. peer_fwd_cnt = header. fwd_cnt . to_ne ( ) ;
127
- raw. tx_waker . wake ( ) ;
128
- raw. rx_waker . wake ( ) ;
129
- hdr = Some ( * header) ;
130
- fwd_cnt = raw. fwd_cnt ;
125
+ if raw. remote_cid == header_cid {
126
+ raw. buffer . extend_from_slice ( data) ;
127
+ raw. fwd_cnt =
128
+ raw. fwd_cnt . wrapping_add ( u32:: try_from ( data. len ( ) ) . unwrap ( ) ) ;
129
+ raw. peer_fwd_cnt = header. fwd_cnt . to_ne ( ) ;
130
+ raw. tx_waker . wake ( ) ;
131
+ raw. rx_waker . wake ( ) ;
132
+ hdr = Some ( * header) ;
133
+ fwd_cnt = raw. fwd_cnt ;
134
+ } else {
135
+ trace ! ( "Receive message from invalid source {}" , header_cid) ;
136
+ }
131
137
} else if op == Op :: CreditUpdate {
132
- raw. peer_fwd_cnt = header. fwd_cnt . to_ne ( ) ;
133
- raw. tx_waker . wake ( ) ;
138
+ if raw. remote_cid == header_cid {
139
+ raw. peer_fwd_cnt = header. fwd_cnt . to_ne ( ) ;
140
+ raw. tx_waker . wake ( ) ;
141
+ } else {
142
+ trace ! ( "Receive message from invalid source {}" , header_cid) ;
143
+ }
134
144
} else if op == Op :: Shutdown {
135
- raw. state = VsockState :: Shutdown ;
145
+ if raw. remote_cid == header_cid {
146
+ raw. state = VsockState :: Shutdown ;
147
+ } else {
148
+ trace ! ( "Receive message from invalid source {}" , header_cid) ;
149
+ }
136
150
} else {
137
- hdr = Some ( * header) ;
138
- fwd_cnt = raw. fwd_cnt ;
151
+ if raw. remote_cid == header_cid {
152
+ hdr = Some ( * header) ;
153
+ fwd_cnt = raw. fwd_cnt ;
154
+ } else {
155
+ trace ! ( "Receive message from invalid source {}" , header_cid) ;
156
+ }
139
157
}
140
158
}
141
159
} ) ;
0 commit comments