@@ -2,7 +2,7 @@ use std::io::{ErrorKind, Read, Write};
2
2
use std:: { ffi:: c_void, ptr:: null} ;
3
3
use std:: { ptr:: null_mut, slice} ;
4
4
5
- use libc:: { size_t, EIO } ;
5
+ use libc:: { size_t, EINVAL , EIO } ;
6
6
use rustls:: {
7
7
Certificate , ClientConnection , ServerConnection , SupportedCipherSuite , ALL_CIPHER_SUITES ,
8
8
} ;
@@ -18,7 +18,7 @@ use crate::{
18
18
cipher:: { rustls_certificate, rustls_supported_ciphersuite} ,
19
19
error:: { map_error, rustls_io_result, rustls_result} ,
20
20
io:: { rustls_read_callback, rustls_write_callback} ,
21
- try_callback, try_mut_slice ,
21
+ try_callback,
22
22
} ;
23
23
use crate :: { ffi_panic_boundary, try_ref_from_ptr} ;
24
24
use crate :: { try_mut_from_ptr, try_slice, userdata_push, CastPtr } ;
@@ -148,15 +148,19 @@ impl rustls_connection {
148
148
) -> rustls_io_result {
149
149
ffi_panic_boundary ! {
150
150
let conn: & mut Connection = try_mut_from_ptr!( conn) ;
151
- let out_n: & mut size_t = try_mut_from_ptr!( out_n) ;
151
+ if out_n. is_null( ) {
152
+ return rustls_io_result( EINVAL )
153
+ }
152
154
let callback: ReadCallback = try_callback!( callback) ;
153
155
154
156
let mut reader = CallbackReader { callback, userdata } ;
155
157
let n_read: usize = match conn. read_tls( & mut reader) {
156
158
Ok ( n) => n,
157
159
Err ( e) => return rustls_io_result( e. raw_os_error( ) . unwrap_or( EIO ) ) ,
158
160
} ;
159
- * out_n = n_read;
161
+ unsafe {
162
+ * out_n = n_read;
163
+ }
160
164
161
165
rustls_io_result( 0 )
162
166
}
@@ -181,15 +185,19 @@ impl rustls_connection {
181
185
) -> rustls_io_result {
182
186
ffi_panic_boundary ! {
183
187
let conn: & mut Connection = try_mut_from_ptr!( conn) ;
184
- let out_n: & mut size_t = try_mut_from_ptr!( out_n) ;
188
+ if out_n. is_null( ) {
189
+ return rustls_io_result( EINVAL )
190
+ }
185
191
let callback: WriteCallback = try_callback!( callback) ;
186
192
187
193
let mut writer = CallbackWriter { callback, userdata } ;
188
194
let n_written: usize = match conn. write_tls( & mut writer) {
189
195
Ok ( n) => n,
190
196
Err ( e) => return rustls_io_result( e. raw_os_error( ) . unwrap_or( EIO ) ) ,
191
197
} ;
192
- * out_n = n_written;
198
+ unsafe {
199
+ * out_n = n_written;
200
+ }
193
201
194
202
rustls_io_result( 0 )
195
203
}
@@ -214,15 +222,19 @@ impl rustls_connection {
214
222
) -> rustls_io_result {
215
223
ffi_panic_boundary ! {
216
224
let conn: & mut Connection = try_mut_from_ptr!( conn) ;
217
- let out_n: & mut size_t = try_mut_from_ptr!( out_n) ;
225
+ if out_n. is_null( ) {
226
+ return rustls_io_result( EINVAL )
227
+ }
218
228
let callback: VectoredWriteCallback = try_callback!( callback) ;
219
229
220
230
let mut writer = VectoredCallbackWriter { callback, userdata } ;
221
231
let n_written: usize = match conn. write_tls( & mut writer) {
222
232
Ok ( n) => n,
223
233
Err ( e) => return rustls_io_result( e. raw_os_error( ) . unwrap_or( EIO ) ) ,
224
234
} ;
225
- * out_n = n_written;
235
+ unsafe {
236
+ * out_n = n_written;
237
+ }
226
238
227
239
rustls_io_result( 0 )
228
240
}
@@ -345,14 +357,15 @@ impl rustls_connection {
345
357
) {
346
358
ffi_panic_boundary ! {
347
359
let conn: & Connection = try_ref_from_ptr!( conn) ;
348
- let protocol_out = try_mut_from_ptr!( protocol_out) ;
349
- let protocol_out_len = try_mut_from_ptr!( protocol_out_len) ;
360
+ if protocol_out. is_null( ) || protocol_out_len. is_null( ) {
361
+ return
362
+ }
350
363
match conn. alpn_protocol( ) {
351
- Some ( p) => {
364
+ Some ( p) => unsafe {
352
365
* protocol_out = p. as_ptr( ) ;
353
366
* protocol_out_len = p. len( ) ;
354
367
} ,
355
- None => {
368
+ None => unsafe {
356
369
* protocol_out = null( ) ;
357
370
* protocol_out_len = 0 ;
358
371
}
@@ -421,17 +434,16 @@ impl rustls_connection {
421
434
ffi_panic_boundary ! {
422
435
let conn: & mut Connection = try_mut_from_ptr!( conn) ;
423
436
let write_buf: & [ u8 ] = try_slice!( buf, count) ;
424
- let out_n: & mut size_t = unsafe {
425
- match out_n. as_mut( ) {
426
- Some ( out_n) => out_n,
427
- None => return NullParameter ,
428
- }
429
- } ;
437
+ if out_n. is_null( ) {
438
+ return NullParameter
439
+ }
430
440
let n_written: usize = match conn. writer( ) . write( write_buf) {
431
441
Ok ( n) => n,
432
442
Err ( _) => return rustls_result:: Io ,
433
443
} ;
434
- * out_n = n_written;
444
+ unsafe {
445
+ * out_n = n_written;
446
+ }
435
447
rustls_result:: Ok
436
448
}
437
449
}
@@ -457,16 +469,28 @@ impl rustls_connection {
457
469
) -> rustls_result {
458
470
ffi_panic_boundary ! {
459
471
let conn: & mut Connection = try_mut_from_ptr!( conn) ;
460
- let read_buf: & mut [ u8 ] = try_mut_slice!( buf, count) ;
461
- let out_n: & mut size_t = try_mut_from_ptr!( out_n) ;
472
+ if buf. is_null( ) {
473
+ return NullParameter
474
+ }
475
+ if out_n. is_null( ) {
476
+ return NullParameter
477
+ }
478
+
479
+ // Safety: the memory pointed at by buf must be initialized
480
+ // (required by documentation of this function).
481
+ let read_buf: & mut [ u8 ] = unsafe {
482
+ slice:: from_raw_parts_mut( buf, count)
483
+ } ;
462
484
463
485
let n_read: usize = match conn. reader( ) . read( read_buf) {
464
486
Ok ( n) => n,
465
487
Err ( e) if e. kind( ) == ErrorKind :: UnexpectedEof => return rustls_result:: UnexpectedEof ,
466
488
Err ( e) if e. kind( ) == ErrorKind :: WouldBlock => return rustls_result:: PlaintextEmpty ,
467
489
Err ( _) => return rustls_result:: Io ,
468
490
} ;
469
- * out_n = n_read;
491
+ unsafe {
492
+ * out_n = n_read;
493
+ }
470
494
rustls_result:: Ok
471
495
}
472
496
}
@@ -494,8 +518,12 @@ impl rustls_connection {
494
518
) -> rustls_result {
495
519
ffi_panic_boundary ! {
496
520
let conn: & mut Connection = try_mut_from_ptr!( conn) ;
497
- let read_buf: & mut [ std:: mem:: MaybeUninit <u8 >] = try_mut_slice!( buf, count) ;
498
- let out_n: & mut size_t = try_mut_from_ptr!( out_n) ;
521
+ if buf. is_null( ) || out_n. is_null( ) {
522
+ return NullParameter
523
+ }
524
+ let read_buf: & mut [ std:: mem:: MaybeUninit <u8 >] = unsafe {
525
+ slice:: from_raw_parts_mut( buf, count)
526
+ } ;
499
527
500
528
let mut read_buf = std:: io:: ReadBuf :: uninit( read_buf) ;
501
529
@@ -505,7 +533,9 @@ impl rustls_connection {
505
533
Err ( e) if e. kind( ) == ErrorKind :: WouldBlock => return rustls_result:: PlaintextEmpty ,
506
534
Err ( _) => return rustls_result:: Io ,
507
535
} ;
508
- * out_n = n_read;
536
+ unsafe {
537
+ * out_n = n_read;
538
+ }
509
539
rustls_result:: Ok
510
540
}
511
541
}
0 commit comments