Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions src/cipher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,14 @@ impl rustls_certificate {
) -> rustls_result {
ffi_panic_boundary! {
let cert = try_ref_from_ptr!(cert);
let out_der_data: &mut *const u8 = try_mut_from_ptr!(out_der_data);
let out_der_len: &mut size_t = try_mut_from_ptr!(out_der_len);
if out_der_data.is_null() || out_der_len.is_null() {
return NullParameter
}
let der = cert.as_ref();
*out_der_data = der.as_ptr();
*out_der_len = der.len();
unsafe {
*out_der_data = der.as_ptr();
*out_der_len = der.len();
}
rustls_result::Ok
}
}
Expand Down
80 changes: 55 additions & 25 deletions src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::io::{ErrorKind, Read, Write};
use std::{ffi::c_void, ptr::null};
use std::{ptr::null_mut, slice};

use libc::{size_t, EIO};
use libc::{size_t, EINVAL, EIO};
use rustls::{
Certificate, ClientConnection, ServerConnection, SupportedCipherSuite, ALL_CIPHER_SUITES,
};
Expand All @@ -18,7 +18,7 @@ use crate::{
cipher::{rustls_certificate, rustls_supported_ciphersuite},
error::{map_error, rustls_io_result, rustls_result},
io::{rustls_read_callback, rustls_write_callback},
try_callback, try_mut_slice,
try_callback,
};
use crate::{ffi_panic_boundary, try_ref_from_ptr};
use crate::{try_mut_from_ptr, try_slice, userdata_push, CastPtr};
Expand Down Expand Up @@ -148,15 +148,19 @@ impl rustls_connection {
) -> rustls_io_result {
ffi_panic_boundary! {
let conn: &mut Connection = try_mut_from_ptr!(conn);
let out_n: &mut size_t = try_mut_from_ptr!(out_n);
if out_n.is_null() {
return rustls_io_result(EINVAL)
}
let callback: ReadCallback = try_callback!(callback);

let mut reader = CallbackReader { callback, userdata };
let n_read: usize = match conn.read_tls(&mut reader) {
Ok(n) => n,
Err(e) => return rustls_io_result(e.raw_os_error().unwrap_or(EIO)),
};
*out_n = n_read;
unsafe {
*out_n = n_read;
}

rustls_io_result(0)
}
Expand All @@ -181,15 +185,19 @@ impl rustls_connection {
) -> rustls_io_result {
ffi_panic_boundary! {
let conn: &mut Connection = try_mut_from_ptr!(conn);
let out_n: &mut size_t = try_mut_from_ptr!(out_n);
if out_n.is_null() {
return rustls_io_result(EINVAL)
}
let callback: WriteCallback = try_callback!(callback);

let mut writer = CallbackWriter { callback, userdata };
let n_written: usize = match conn.write_tls(&mut writer) {
Ok(n) => n,
Err(e) => return rustls_io_result(e.raw_os_error().unwrap_or(EIO)),
};
*out_n = n_written;
unsafe {
*out_n = n_written;
}

rustls_io_result(0)
}
Expand All @@ -214,15 +222,19 @@ impl rustls_connection {
) -> rustls_io_result {
ffi_panic_boundary! {
let conn: &mut Connection = try_mut_from_ptr!(conn);
let out_n: &mut size_t = try_mut_from_ptr!(out_n);
if out_n.is_null() {
return rustls_io_result(EINVAL)
}
let callback: VectoredWriteCallback = try_callback!(callback);

let mut writer = VectoredCallbackWriter { callback, userdata };
let n_written: usize = match conn.write_tls(&mut writer) {
Ok(n) => n,
Err(e) => return rustls_io_result(e.raw_os_error().unwrap_or(EIO)),
};
*out_n = n_written;
unsafe {
*out_n = n_written;
}

rustls_io_result(0)
}
Expand Down Expand Up @@ -345,14 +357,15 @@ impl rustls_connection {
) {
ffi_panic_boundary! {
let conn: &Connection = try_ref_from_ptr!(conn);
let protocol_out = try_mut_from_ptr!(protocol_out);
let protocol_out_len = try_mut_from_ptr!(protocol_out_len);
if protocol_out.is_null() || protocol_out_len.is_null() {
return
}
match conn.alpn_protocol() {
Some(p) => {
Some(p) => unsafe {
*protocol_out = p.as_ptr();
*protocol_out_len = p.len();
},
None => {
None => unsafe {
*protocol_out = null();
*protocol_out_len = 0;
}
Expand Down Expand Up @@ -421,17 +434,16 @@ impl rustls_connection {
ffi_panic_boundary! {
let conn: &mut Connection = try_mut_from_ptr!(conn);
let write_buf: &[u8] = try_slice!(buf, count);
let out_n: &mut size_t = unsafe {
match out_n.as_mut() {
Some(out_n) => out_n,
None => return NullParameter,
}
};
if out_n.is_null() {
return NullParameter
}
let n_written: usize = match conn.writer().write(write_buf) {
Ok(n) => n,
Err(_) => return rustls_result::Io,
};
*out_n = n_written;
unsafe {
*out_n = n_written;
}
rustls_result::Ok
}
}
Expand All @@ -457,16 +469,28 @@ impl rustls_connection {
) -> rustls_result {
ffi_panic_boundary! {
let conn: &mut Connection = try_mut_from_ptr!(conn);
let read_buf: &mut [u8] = try_mut_slice!(buf, count);
let out_n: &mut size_t = try_mut_from_ptr!(out_n);
if buf.is_null() {
return NullParameter
}
if out_n.is_null() {
return NullParameter
}

// Safety: the memory pointed at by buf must be initialized
// (required by documentation of this function).
let read_buf: &mut [u8] = unsafe {
slice::from_raw_parts_mut(buf, count)
};

let n_read: usize = match conn.reader().read(read_buf) {
Ok(n) => n,
Err(e) if e.kind() == ErrorKind::UnexpectedEof => return rustls_result::UnexpectedEof,
Err(e) if e.kind() == ErrorKind::WouldBlock => return rustls_result::PlaintextEmpty,
Err(_) => return rustls_result::Io,
};
*out_n = n_read;
unsafe {
*out_n = n_read;
}
rustls_result::Ok
}
}
Expand Down Expand Up @@ -494,8 +518,12 @@ impl rustls_connection {
) -> rustls_result {
ffi_panic_boundary! {
let conn: &mut Connection = try_mut_from_ptr!(conn);
let read_buf: &mut [std::mem::MaybeUninit<u8>] = try_mut_slice!(buf, count);
let out_n: &mut size_t = try_mut_from_ptr!(out_n);
if buf.is_null() || out_n.is_null() {
return NullParameter
}
let read_buf: &mut [std::mem::MaybeUninit<u8>] = unsafe {
slice::from_raw_parts_mut(buf, count)
};

let mut read_buf = std::io::ReadBuf::uninit(read_buf);

Expand All @@ -505,7 +533,9 @@ impl rustls_connection {
Err(e) if e.kind() == ErrorKind::WouldBlock => return rustls_result::PlaintextEmpty,
Err(_) => return rustls_result::Io,
};
*out_n = n_read;
unsafe {
*out_n = n_read;
}
rustls_result::Ok
}
}
Expand Down
27 changes: 12 additions & 15 deletions src/error.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use std::{cmp::min, convert::TryFrom, fmt::Display, slice};
use std::cmp::min;
use std::convert::TryFrom;
use std::fmt::Display;

use crate::ffi_panic_boundary;
use libc::{c_char, c_uint, size_t};
Expand All @@ -25,23 +27,18 @@ impl rustls_result {
out_n: *mut size_t,
) {
ffi_panic_boundary! {
let write_buf: &mut [u8] = unsafe {
let out_n: &mut size_t = match out_n.as_mut() {
Some(out_n) => out_n,
None => return,
};
*out_n = 0;
if buf.is_null() {
return;
}
slice::from_raw_parts_mut(buf as *mut u8, len as usize)
};
if buf.is_null() {
return
}
if out_n.is_null() {
return
}
let result: rustls_result = rustls_result::try_from(result).unwrap_or(rustls_result::InvalidParameter);
let error_str = result.to_string();
let len: usize = min(write_buf.len() - 1, error_str.len());
write_buf[..len].copy_from_slice(&error_str.as_bytes()[..len]);
let out_len: usize = min(len - 1, error_str.len());
unsafe {
*out_n = len;
std::ptr::copy_nonoverlapping(error_str.as_ptr() as *mut c_char, buf, out_len);
*out_n = out_len;
}
}
}
Expand Down
10 changes: 1 addition & 9 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
//! [You may also want to read the rustls-ffi README](https://github.com/rustls/rustls-ffi#rustls-ffi-bindings).

use crate::rslice::rustls_str;
use libc::{c_void, size_t};
use libc::c_void;
use std::cell::RefCell;
use std::mem;
use std::sync::Arc;
Expand Down Expand Up @@ -436,14 +436,6 @@ where
F::to_arc(from)
}

impl CastPtr for size_t {
type RustType = size_t;
}

impl CastPtr for *const u8 {
type RustType = *const u8;
}

/// If the provided pointer is non-null, convert it to a reference.
/// Otherwise, return NullParameter, or an appropriate default (false, 0, NULL)
/// based on the context;
Expand Down
36 changes: 22 additions & 14 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ use crate::session::{
SessionStoreGetCallback, SessionStorePutCallback,
};
use crate::{
ffi_panic_boundary, try_arc_from_ptr, try_box_from_ptr, try_mut_from_ptr, try_mut_slice,
try_ref_from_ptr, try_slice, userdata_get, ArcCastPtr, BoxCastPtr, CastConstPtr, CastPtr,
ffi_panic_boundary, try_arc_from_ptr, try_box_from_ptr, try_mut_from_ptr, try_ref_from_ptr,
try_slice, userdata_get, ArcCastPtr, BoxCastPtr, CastConstPtr, CastPtr,
};

/// A server config being constructed. A builder can be modified by,
Expand Down Expand Up @@ -371,8 +371,12 @@ pub extern "C" fn rustls_server_connection_get_sni_hostname(
) -> rustls_result {
ffi_panic_boundary! {
let conn: &Connection = try_ref_from_ptr!(conn);
let write_buf: &mut [u8] = try_mut_slice!(buf, count);
let out_n: &mut size_t = try_mut_from_ptr!(out_n);
if buf.is_null() {
return NullParameter
}
if out_n.is_null() {
return NullParameter
}
let server_connection = match conn.as_server() {
Some(s) => s,
_ => return rustls_result::InvalidParameter,
Expand All @@ -384,11 +388,16 @@ pub extern "C" fn rustls_server_connection_get_sni_hostname(
},
};
let len: usize = sni_hostname.len();
if len > write_buf.len() {
if len > count {
unsafe {
*out_n = 0
}
return rustls_result::InsufficientSize;
}
write_buf[..len].copy_from_slice(sni_hostname.as_bytes());
*out_n = len;
unsafe {
std::ptr::copy_nonoverlapping(sni_hostname.as_ptr(), buf, len);
*out_n = len;
}
rustls_result::Ok
}
}
Expand Down Expand Up @@ -625,17 +634,16 @@ pub extern "C" fn rustls_client_hello_select_certified_key(
ffi_panic_boundary! {
let hello = try_ref_from_ptr!(hello);
let schemes: Vec<SignatureScheme> = sigschemes(try_slice!(hello.signature_schemes.data, hello.signature_schemes.len));
let out_key: &mut *const rustls_certified_key = unsafe {
match out_key.as_mut() {
Some(out_key) => out_key,
None => return NullParameter,
}
};
if out_key.is_null() {
return NullParameter
}
let keys_ptrs: &[*const rustls_certified_key] = try_slice!(certified_keys, certified_keys_len);
for &key_ptr in keys_ptrs {
let key_ref: &CertifiedKey = try_ref_from_ptr!(key_ptr);
if key_ref.key.choose_scheme(&schemes).is_some() {
*out_key = key_ptr;
unsafe {
*out_key = key_ptr;
}
return rustls_result::Ok;
}
}
Expand Down