diff --git a/gen/src/builtin.rs b/gen/src/builtin.rs index 9c5f4fd78..0972dd8fb 100644 --- a/gen/src/builtin.rs +++ b/gen/src/builtin.rs @@ -148,8 +148,9 @@ pub(super) fn write(out: &mut OutFile) { if builtin.ptr_len { out.begin_block(Block::Namespace("repr")); + writeln!(out, "template "); writeln!(out, "struct PtrLen final {{"); - writeln!(out, " const void *ptr;"); + writeln!(out, " T *ptr;"); writeln!(out, " size_t len;"); writeln!(out, "}};"); out.end_block(Block::Namespace("repr")); @@ -163,7 +164,7 @@ pub(super) fn write(out: &mut OutFile) { if builtin.rust_str_new_unchecked { writeln!( out, - " static Str new_unchecked(repr::PtrLen repr) noexcept {{", + " static Str new_unchecked(repr::PtrLen repr) noexcept {{", ); writeln!(out, " Str str;"); writeln!(out, " str.ptr = static_cast(repr.ptr);"); @@ -172,8 +173,8 @@ pub(super) fn write(out: &mut OutFile) { writeln!(out, " }}"); } if builtin.rust_str_repr { - writeln!(out, " static repr::PtrLen repr(Str str) noexcept {{"); - writeln!(out, " return repr::PtrLen{{str.ptr, str.len}};"); + writeln!(out, " static repr::PtrLen repr(Str str) noexcept {{"); + writeln!(out, " return repr::PtrLen{{str.ptr, str.len}};"); writeln!(out, " }}"); } writeln!(out, "}};"); @@ -187,7 +188,7 @@ pub(super) fn write(out: &mut OutFile) { if builtin.rust_slice_new { writeln!( out, - " static Slice slice(repr::PtrLen repr) noexcept {{", + " static Slice slice(repr::PtrLen repr) noexcept {{", ); writeln!( out, @@ -196,11 +197,20 @@ pub(super) fn write(out: &mut OutFile) { writeln!(out, " }}"); } if builtin.rust_slice_repr { + writeln!(out, " template "); + writeln!( + out, + " static typename std::enable_if::value, repr::PtrLen>::type repr(Slice slice) noexcept {{", + ); + writeln!(out, " return repr::PtrLen{{slice.ptr, slice.len}};"); + writeln!(out, " }}"); + + writeln!(out, " template "); writeln!( out, - " static repr::PtrLen repr(Slice slice) noexcept {{", + " static typename std::enable_if::value, repr::PtrLen>::type repr(Slice slice) noexcept {{", ); - writeln!(out, " return repr::PtrLen{{slice.ptr, slice.len}};"); + writeln!(out, " return repr::PtrLen{{slice.ptr, slice.len}};"); writeln!(out, " }}"); } writeln!(out, "}};"); @@ -211,7 +221,7 @@ pub(super) fn write(out: &mut OutFile) { writeln!(out, "template <>"); writeln!(out, "class impl final {{"); writeln!(out, "public:"); - writeln!(out, " static Error error(repr::PtrLen repr) noexcept {{"); + writeln!(out, " static Error error(repr::PtrLen repr) noexcept {{"); writeln!(out, " Error error;"); writeln!(out, " error.msg = static_cast(repr.ptr);"); writeln!(out, " error.len = repr.len;"); diff --git a/gen/src/write.rs b/gen/src/write.rs index d99778605..fc397fb68 100644 --- a/gen/src/write.rs +++ b/gen/src/write.rs @@ -343,7 +343,7 @@ fn write_cxx_function_shim<'a>(out: &mut OutFile<'a>, efn: &'a ExternFn) { } if efn.throws { out.builtin.ptr_len = true; - write!(out, "::rust::repr::PtrLen "); + write!(out, "::rust::repr::PtrLen "); } else { write_extern_return_type_space(out, &efn.ret); } @@ -417,7 +417,7 @@ fn write_cxx_function_shim<'a>(out: &mut OutFile<'a>, efn: &'a ExternFn) { if efn.throws { out.builtin.ptr_len = true; out.builtin.trycatch = true; - writeln!(out, "::rust::repr::PtrLen throw$;"); + writeln!(out, "::rust::repr::PtrLen throw$;"); writeln!(out, " ::rust::behavior::trycatch("); writeln!(out, " [&] {{"); write!(out, " "); @@ -436,10 +436,14 @@ fn write_cxx_function_shim<'a>(out: &mut OutFile<'a>, efn: &'a ExternFn) { out.builtin.rust_str_repr = true; write!(out, "::rust::impl<::rust::Str>::repr("); } - Some(Type::SliceRefU8(_)) if !indirect_return => { + Some(Type::SliceRefU8(ty)) if !indirect_return && ty.mutability.is_none() => { out.builtin.rust_slice_repr = true; write!(out, "::rust::impl<::rust::Slice>::repr(") } + Some(Type::SliceRefU8(_)) if !indirect_return => { + out.builtin.rust_slice_repr = true; + write!(out, "::rust::impl<::rust::Slice>::repr(") + } _ => {} } match &efn.receiver { @@ -474,12 +478,19 @@ fn write_cxx_function_shim<'a>(out: &mut OutFile<'a>, efn: &'a ExternFn) { out.builtin.unsafe_bitcopy = true; write_type(out, &arg.ty); write!(out, "(::rust::unsafe_bitcopy, *{})", arg.ident); - } else if let Type::SliceRefU8(_) = arg.ty { - write!( - out, - "::rust::Slice(static_cast({0}.ptr), {0}.len)", - arg.ident, - ); + } else if let Type::SliceRefU8(ref ty) = arg.ty { + match ty.mutability { + None => write!( + out, + "::rust::Slice(static_cast({0}.ptr), {0}.len)", + arg.ident, + ), + Some(_) => write!( + out, + "::rust::Slice(static_cast({0}.ptr), {0}.len)", + arg.ident, + ) + } } else if out.types.needs_indirect_abi(&arg.ty) { out.include.utility = true; write!(out, "::std::move(*{})", arg.ident); @@ -555,7 +566,7 @@ fn write_rust_function_decl_impl( out.next_section(); if sig.throws { out.builtin.ptr_len = true; - write!(out, "::rust::repr::PtrLen "); + write!(out, "::rust::repr::PtrLen "); } else { write_extern_return_type_space(out, &sig.ret); } @@ -701,7 +712,7 @@ fn write_rust_function_shim_impl( } if sig.throws { out.builtin.ptr_len = true; - write!(out, "::rust::repr::PtrLen error$ = "); + write!(out, "::rust::repr::PtrLen error$ = "); } write!(out, "{}(", invoke); let mut needs_comma = false; @@ -718,9 +729,13 @@ fn write_rust_function_shim_impl( out.builtin.rust_str_repr = true; write!(out, "::rust::impl<::rust::Str>::repr("); } - Type::SliceRefU8(_) => { + Type::SliceRefU8(ty) => { out.builtin.rust_slice_repr = true; - write!(out, "::rust::impl<::rust::Slice>::repr("); + if ty.mutability.is_none() { + write!(out, "::rust::impl<::rust::Slice>::repr("); + } else { + write!(out, "::rust::impl<::rust::Slice>::repr("); + } } ty if out.types.needs_indirect_abi(ty) => write!(out, "&"), _ => {} @@ -823,9 +838,14 @@ fn write_extern_return_type_space(out: &mut OutFile, ty: &Option) { write_type(out, &ty.inner); write!(out, " *"); } - Some(Type::Str(_)) | Some(Type::SliceRefU8(_)) => { + Some(Type::Str(ty)) | Some(Type::SliceRefU8(ty)) => { out.builtin.ptr_len = true; - write!(out, "::rust::repr::PtrLen "); + + if ty.mutable { + write!(out, "::rust::repr::PtrLen "); + } else { + write!(out, "::rust::repr::PtrLen "); + } } Some(ty) if out.types.needs_indirect_abi(ty) => write!(out, "void "), _ => write_return_type(out, ty), @@ -838,9 +858,13 @@ fn write_extern_arg(out: &mut OutFile, arg: &Var) { write_type_space(out, &ty.inner); write!(out, "*"); } - Type::Str(_) | Type::SliceRefU8(_) => { + Type::Str(ty) | Type::SliceRefU8(ty) => { out.builtin.ptr_len = true; - write!(out, "::rust::repr::PtrLen "); + if ty.mutable { + write!(out, "::rust::repr::PtrLen "); + } else { + write!(out, "::rust::repr::PtrLen "); + } } _ => write_type_space(out, &arg.ty), } @@ -890,9 +914,12 @@ fn write_type(out: &mut OutFile, ty: &Type) { Type::Str(_) => { write!(out, "::rust::Str"); } - Type::SliceRefU8(_) => { + Type::SliceRefU8(ty) if ty.mutability.is_none() => { write!(out, "::rust::Slice"); } + Type::SliceRefU8(_) => { + write!(out, "::rust::Slice"); + } Type::Fn(f) => { write!(out, "::rust::{}<", if f.throws { "TryFn" } else { "Fn" }); match &f.ret { diff --git a/include/cxx.h b/include/cxx.h index 4ed3302ba..d9210d0f5 100644 --- a/include/cxx.h +++ b/include/cxx.h @@ -89,21 +89,18 @@ class Str final { #ifndef CXXBRIDGE1_RUST_SLICE template class Slice final { - static_assert(std::is_const::value, - "&[T] needs to be written as rust::Slice in C++"); - public: Slice() noexcept; Slice(T *, size_t count) noexcept; - Slice &operator=(const Slice &) noexcept = default; + Slice &operator=(const Slice &that) noexcept = default; T *data() const noexcept; size_t size() const noexcept; size_t length() const noexcept; // Important in order for System V ABI to pass in registers. - Slice(const Slice &) noexcept = default; + Slice(const Slice &that) noexcept = default; ~Slice() noexcept = default; private: diff --git a/macro/src/expand.rs b/macro/src/expand.rs index f29c4cc1e..08c6a76ec 100644 --- a/macro/src/expand.rs +++ b/macro/src/expand.rs @@ -330,7 +330,8 @@ fn expand_cxx_function_shim(efn: &ExternFn, types: &Types) -> TokenStream { _ => quote!(#var), }, Type::Str(_) => quote!(::cxx::private::RustStr::from(#var)), - Type::SliceRefU8(_) => quote!(::cxx::private::RustSliceU8::from(#var)), + Type::SliceRefU8(ty) if ty.mutability.is_none() => quote!(::cxx::private::RustSliceU8::from(#var)), + Type::SliceRefU8(_) => quote!(::cxx::private::RustMutSliceU8::from(#var)), ty if types.needs_indirect_abi(ty) => quote!(#var.as_mut_ptr()), _ => quote!(#var), } @@ -423,7 +424,8 @@ fn expand_cxx_function_shim(efn: &ExternFn, types: &Types) -> TokenStream { _ => call, }, Type::Str(_) => quote!(#call.as_str()), - Type::SliceRefU8(_) => quote!(#call.as_slice()), + Type::SliceRefU8(ty) if ty.mutability.is_none() => quote!(#call.as_slice()), + Type::SliceRefU8(ty) if ty.mutability.is_some() => quote!(#call.as_mut_slice()), _ => call, }, }; @@ -610,7 +612,8 @@ fn expand_rust_function_shim_impl( _ => quote!(#ident), }, Type::Str(_) => quote!(#ident.as_str()), - Type::SliceRefU8(_) => quote!(#ident.as_slice()), + Type::SliceRefU8(ty) if ty.mutability.is_none() => quote!(#ident.as_slice()), + Type::SliceRefU8(ty) if ty.mutability.is_some() => quote!(#ident.as_mut_slice()), ty if types.needs_indirect_abi(ty) => quote!(::std::ptr::read(#ident)), _ => quote!(#ident), } @@ -654,7 +657,8 @@ fn expand_rust_function_shim_impl( _ => None, }, Type::Str(_) => Some(quote!(::cxx::private::RustStr::from)), - Type::SliceRefU8(_) => Some(quote!(::cxx::private::RustSliceU8::from)), + Type::SliceRefU8(ty) if ty.mutability.is_none() => Some(quote!(::cxx::private::RustSliceU8::from)), + Type::SliceRefU8(_) => Some(quote!(::cxx::private::RustMutSliceU8::from)), _ => None, }); @@ -1093,7 +1097,8 @@ fn expand_extern_type(ty: &Type, types: &Types, proper: bool) -> TokenStream { } } Type::Str(_) => quote!(::cxx::private::RustStr), - Type::SliceRefU8(_) => quote!(::cxx::private::RustSliceU8), + Type::SliceRefU8(ty) if ty.mutability.is_none() => quote!(::cxx::private::RustSliceU8), + Type::SliceRefU8(_) => quote!(::cxx::private::RustMutSliceU8), _ => quote!(#ty), } } diff --git a/src/lib.rs b/src/lib.rs index d921a54ba..78505d055 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -397,6 +397,7 @@ mod function; mod opaque; mod result; mod rust_sliceu8; +mod rust_mutsliceu8; mod rust_str; mod rust_string; mod rust_type; @@ -441,6 +442,7 @@ pub mod private { pub use crate::opaque::Opaque; pub use crate::result::{r#try, Result}; pub use crate::rust_sliceu8::RustSliceU8; + pub use crate::rust_mutsliceu8::RustMutSliceU8; pub use crate::rust_str::RustStr; pub use crate::rust_string::RustString; pub use crate::rust_type::RustType; diff --git a/src/rust_mutsliceu8.rs b/src/rust_mutsliceu8.rs new file mode 100644 index 000000000..2e290436e --- /dev/null +++ b/src/rust_mutsliceu8.rs @@ -0,0 +1,30 @@ +use core::mem; +use core::ptr::NonNull; +use core::slice; + +// Not necessarily ABI compatible with &mut [u8]. Codegen performs the translation. +#[repr(C)] +#[derive(Copy, Clone)] +pub struct RustMutSliceU8 { + pub(crate) ptr: NonNull, + pub(crate) len: usize, +} + +impl RustMutSliceU8 { + pub fn from(s: &mut [u8]) -> Self { + let len = s.len(); + RustMutSliceU8 { + ptr: NonNull::from(s).cast::(), + len, + } + } + + pub unsafe fn as_mut_slice<'a>(self) -> &'a mut [u8] { + slice::from_raw_parts_mut(self.ptr.as_ptr(), self.len) + } +} + +const_assert_eq!( + mem::size_of::>(), + mem::size_of::(), +); diff --git a/syntax/check.rs b/syntax/check.rs index 2e94f4db2..9d656b690 100644 --- a/syntax/check.rs +++ b/syntax/check.rs @@ -178,7 +178,7 @@ fn check_type_ref(cx: &mut Check, ty: &Ref) { } fn check_type_slice(cx: &mut Check, ty: &Slice) { - cx.error(ty, "only &[u8] is supported so far, not other slice types"); + cx.error(ty, "only &[u8] and &mut [u8] are supported so far, not other slice types"); } fn check_type_fn(cx: &mut Check, ty: &Signature) { @@ -489,7 +489,8 @@ fn describe(cx: &mut Check, ty: &Type) -> String { Type::Str(_) => "&str".to_owned(), Type::CxxVector(_) => "C++ vector".to_owned(), Type::Slice(_) => "slice".to_owned(), - Type::SliceRefU8(_) => "&[u8]".to_owned(), + Type::SliceRefU8(ty) if ty.mutability.is_none() => "&[u8]".to_owned(), + Type::SliceRefU8(_) => "&mut [u8]".to_owned(), Type::Fn(_) => "function pointer".to_owned(), Type::Void(_) => "()".to_owned(), } diff --git a/syntax/parse.rs b/syntax/parse.rs index 38b3dccf4..c6d64804d 100644 --- a/syntax/parse.rs +++ b/syntax/parse.rs @@ -658,7 +658,7 @@ fn parse_type_reference(ty: &TypeReference, namespace: &Namespace) -> Result match &slice.inner { - Type::Ident(ident) if ident.rust == U8 && ty.mutability.is_none() => Type::SliceRefU8, + Type::Ident(ident) if ident.rust == U8 => Type::SliceRefU8, _ => Type::Ref, }, _ => Type::Ref, diff --git a/tests/cxx_gen.rs b/tests/cxx_gen.rs index d33092c74..e4299a3fd 100644 --- a/tests/cxx_gen.rs +++ b/tests/cxx_gen.rs @@ -20,7 +20,7 @@ fn test_extern_c_function() { let output = str::from_utf8(&generated.implementation).unwrap(); // To avoid continual breakage we won't test every byte. // Let's look for the major features. - assert!(output.contains("void cxxbridge1$do_cpp_thing(::rust::repr::PtrLen foo)")); + assert!(output.contains("void cxxbridge1$do_cpp_thing(::rust::repr::PtrLen foo)")); } #[test] @@ -30,5 +30,5 @@ fn test_impl_annotation() { let source = BRIDGE0.parse().unwrap(); let generated = generate_header_and_cc(source, &opt).unwrap(); let output = str::from_utf8(&generated.implementation).unwrap(); - assert!(output.contains("ANNOTATION void cxxbridge1$do_cpp_thing(::rust::repr::PtrLen foo)")); + assert!(output.contains("ANNOTATION void cxxbridge1$do_cpp_thing(::rust::repr::PtrLen foo)")); } diff --git a/tests/ffi/lib.rs b/tests/ffi/lib.rs index 032adc47d..c6967deed 100644 --- a/tests/ffi/lib.rs +++ b/tests/ffi/lib.rs @@ -129,6 +129,7 @@ pub mod ffi { fn c_return_mut(shared: &mut Shared) -> &mut usize; fn c_return_str(shared: &Shared) -> &str; fn c_return_sliceu8(shared: &Shared) -> &[u8]; + fn c_return_mutsliceu8(shared: &Shared) -> &mut [u8]; fn c_return_rust_string() -> String; fn c_return_unique_ptr_string() -> UniquePtr; fn c_return_unique_ptr_vector_u8() -> UniquePtr>; @@ -157,6 +158,7 @@ pub mod ffi { fn c_take_ref_c(c: &C); fn c_take_str(s: &str); fn c_take_sliceu8(s: &[u8]); + fn c_take_mutsliceu8(s: &mut [u8]); fn c_take_rust_string(s: String); fn c_take_unique_ptr_string(s: UniquePtr); fn c_take_unique_ptr_vector_u8(v: UniquePtr>); @@ -192,6 +194,7 @@ pub mod ffi { fn c_try_return_ref(s: &String) -> Result<&String>; fn c_try_return_str(s: &str) -> Result<&str>; fn c_try_return_sliceu8(s: &[u8]) -> Result<&[u8]>; + fn c_try_return_mutsliceu8(s: &mut [u8]) -> Result<&mut [u8]>; fn c_try_return_rust_string() -> Result; fn c_try_return_unique_ptr_string() -> Result>; fn c_try_return_rust_vec() -> Result>; @@ -256,6 +259,7 @@ pub mod ffi { fn r_take_ref_c(c: &C); fn r_take_str(s: &str); fn r_take_sliceu8(s: &[u8]); + fn r_take_mutsliceu8(s: &mut [u8]); fn r_take_rust_string(s: String); fn r_take_unique_ptr_string(s: UniquePtr); fn r_take_ref_vector(v: &CxxVector); @@ -448,6 +452,11 @@ fn r_take_sliceu8(s: &[u8]) { assert_eq!(std::str::from_utf8(s).unwrap(), "2020\0"); } +fn r_take_mutsliceu8(s: &mut [u8]) { + assert_eq!(s.len(), 5); + assert_eq!(std::str::from_utf8(s).unwrap(), "2020\0"); +} + fn r_take_unique_ptr_string(s: UniquePtr) { assert_eq!(s.as_ref().unwrap().to_str().unwrap(), "2020"); } diff --git a/tests/ffi/tests.h b/tests/ffi/tests.h index 5acf469db..ba56a000e 100644 --- a/tests/ffi/tests.h +++ b/tests/ffi/tests.h @@ -84,6 +84,7 @@ const size_t &c_return_nested_ns_ref(const ::A::B::ABShared &shared); size_t &c_return_mut(Shared &shared); rust::Str c_return_str(const Shared &shared); rust::Slice c_return_sliceu8(const Shared &shared); +rust::Slice c_return_mutsliceu8(const Shared &shared); rust::String c_return_rust_string(); std::unique_ptr c_return_unique_ptr_string(); std::unique_ptr> c_return_unique_ptr_vector_u8(); @@ -114,6 +115,7 @@ void c_take_ref_c(const C &c); void c_take_ref_ns_c(const ::H::H &h); void c_take_str(rust::Str s); void c_take_sliceu8(rust::Slice s); +void c_take_mutsliceu8(rust::Slice s); void c_take_rust_string(rust::String s); void c_take_unique_ptr_string(std::unique_ptr s); void c_take_unique_ptr_vector_u8(std::unique_ptr> v); @@ -148,6 +150,7 @@ rust::Box c_try_return_box(); const rust::String &c_try_return_ref(const rust::String &); rust::Str c_try_return_str(rust::Str); rust::Slice c_try_return_sliceu8(rust::Slice); +rust::Slice c_try_return_mutsliceu8(rust::Slice); rust::String c_try_return_rust_string(); std::unique_ptr c_try_return_unique_ptr_string(); rust::Vec c_try_return_rust_vec();