diff --git a/webrtc/src/peer_connection/mod.rs b/webrtc/src/peer_connection/mod.rs index 779e6b58f..698b95340 100644 --- a/webrtc/src/peer_connection/mod.rs +++ b/webrtc/src/peer_connection/mod.rs @@ -473,7 +473,7 @@ impl RTCPeerConnection { None => return true, // doesn't contain a single a=msid line }; - let sender = t.sender(); + let sender = t.sender().await; // (...)or the number of MSIDs from the a=msid lines in this m= section, // or the MSID values themselves, differ from what is in // transceiver.sender.[[AssociatedMediaStreamIds]], return true. @@ -1595,7 +1595,7 @@ impl RTCPeerConnection { pub(crate) async fn start_rtp_senders(&self) -> Result<()> { let current_transceivers = self.internal.rtp_transceivers.lock().await; for transceiver in &*current_transceivers { - let sender = transceiver.sender(); + let sender = transceiver.sender().await; if sender.is_negotiated() && !sender.has_sent() { sender.send(&sender.get_parameters().await).await?; } @@ -1653,7 +1653,7 @@ impl RTCPeerConnection { let mut senders = vec![]; let rtp_transceivers = self.internal.rtp_transceivers.lock().await; for transceiver in &*rtp_transceivers { - let sender = transceiver.sender(); + let sender = transceiver.sender().await; senders.push(sender); } senders @@ -1664,7 +1664,7 @@ impl RTCPeerConnection { let mut receivers = vec![]; let rtp_transceivers = self.internal.rtp_transceivers.lock().await; for transceiver in &*rtp_transceivers { - receivers.push(transceiver.receiver()); + receivers.push(transceiver.receiver().await); } receivers } @@ -1688,7 +1688,7 @@ impl RTCPeerConnection { let rtp_transceivers = self.internal.rtp_transceivers.lock().await; for t in &*rtp_transceivers { if !t.stopped.load(Ordering::SeqCst) && t.kind == track.kind() { - let sender = t.sender(); + let sender = t.sender().await; if sender.track().await.is_none() { if let Err(err) = sender.replace_track(Some(track)).await { let _ = sender.stop().await; @@ -1715,7 +1715,7 @@ impl RTCPeerConnection { .add_rtp_transceiver(Arc::clone(&transceiver)) .await; - Ok(transceiver.sender()) + Ok(transceiver.sender().await) } /// remove_track removes a Track from the PeerConnection @@ -1728,7 +1728,7 @@ impl RTCPeerConnection { { let rtp_transceivers = self.internal.rtp_transceivers.lock().await; for t in &*rtp_transceivers { - if t.sender().id == sender.id { + if t.sender().await.id == sender.id { if sender.track().await.is_none() { return Ok(()); } diff --git a/webrtc/src/peer_connection/peer_connection_internal.rs b/webrtc/src/peer_connection/peer_connection_internal.rs index 0c4e5d113..a77dc2f3c 100644 --- a/webrtc/src/peer_connection/peer_connection_internal.rs +++ b/webrtc/src/peer_connection/peer_connection_internal.rs @@ -167,7 +167,7 @@ impl PeerConnectionInternal { self.undeclared_media_processor(); } else { for t in ¤t_transceivers { - let receiver = t.receiver(); + let receiver = t.receiver().await; let tracks = receiver.tracks().await; if tracks.is_empty() { continue; @@ -217,7 +217,7 @@ impl PeerConnectionInternal { Arc::clone(&self.media_engine), interceptor, )); - t.set_receiver(receiver); + t.set_receiver(receiver).await; } } @@ -338,7 +338,7 @@ impl PeerConnectionInternal { for incoming_track in incoming_tracks { // If we already have a TrackRemote for a given SSRC don't handle it again for t in local_transceivers { - let receiver = t.receiver(); + let receiver = t.receiver().await; for track in receiver.tracks().await { for ssrc in &incoming_track.ssrcs { if *ssrc == track.ssrc() { @@ -364,7 +364,7 @@ impl PeerConnectionInternal { continue; } - let receiver = t.receiver(); + let receiver = t.receiver().await; if receiver.have_received().await { continue; } @@ -667,7 +667,7 @@ impl PeerConnectionInternal { } // TODO: This is dubious because of rollbacks. - t.sender().set_negotiated(); + t.sender().await.set_negotiated(); media_sections.push(MediaSection { id: t.mid().unwrap().0.to_string(), transceivers: vec![Arc::clone(t)], @@ -756,7 +756,7 @@ impl PeerConnectionInternal { } if let Some(t) = find_by_mid(mid_value, &mut local_transceivers).await { - t.sender().set_negotiated(); + t.sender().await.set_negotiated(); let media_transceivers = vec![t]; // NB: The below could use `then_some`, but with our current MSRV @@ -781,7 +781,7 @@ impl PeerConnectionInternal { // If we are offering also include unmatched local transceivers if include_unmatched { for t in &local_transceivers { - t.sender().set_negotiated(); + t.sender().await.set_negotiated(); media_sections.push(MediaSection { id: t.mid().unwrap().0.to_string(), transceivers: vec![Arc::clone(t)], @@ -887,7 +887,7 @@ impl PeerConnectionInternal { ) .await?; - let receiver = t.receiver(); + let receiver = t.receiver().await; PeerConnectionInternal::start_receiver( self.setting_engine.get_receive_mtu(), &incoming, @@ -1008,7 +1008,7 @@ impl PeerConnectionInternal { continue; } - let receiver = t.receiver(); + let receiver = t.receiver().await; if !rsid.is_empty() { return receiver @@ -1210,7 +1210,7 @@ impl PeerConnectionInternal { } let mut track_infos = vec![]; for transeiver in transceivers { - let receiver = transeiver.receiver(); + let receiver = transeiver.receiver().await; if let Some(mid) = transeiver.mid() { let tracks = receiver.tracks().await; @@ -1335,7 +1335,7 @@ impl PeerConnectionInternal { } let mut track_infos = vec![]; for transceiver in transceivers { - let sender = transceiver.sender(); + let sender = transceiver.sender().await; let mid = match transceiver.mid() { Some(mid) => mid, diff --git a/webrtc/src/peer_connection/peer_connection_test.rs b/webrtc/src/peer_connection/peer_connection_test.rs index ca618d097..ea9ca4c4e 100644 --- a/webrtc/src/peer_connection/peer_connection_test.rs +++ b/webrtc/src/peer_connection/peer_connection_test.rs @@ -1,14 +1,22 @@ use super::*; +use crate::api::interceptor_registry::register_default_interceptors; +use crate::api::media_engine::MediaEngine; use crate::api::media_engine::MIME_TYPE_VP8; use crate::api::APIBuilder; use crate::ice_transport::ice_candidate_pair::RTCIceCandidatePair; +use crate::ice_transport::ice_server::RTCIceServer; +use crate::peer_connection::configuration::RTCConfiguration; use crate::rtp_transceiver::rtp_codec::RTCRtpCodecCapability; use crate::stats::StatsReportType; use crate::track::track_local::track_local_static_sample::TrackLocalStaticSample; +use crate::Error; +use interceptor::registry::Registry; + use bytes::Bytes; use media::Sample; use std::sync::atomic::AtomicU32; +use std::sync::Arc; use tokio::time::Duration; use util::vnet::net::{Net, NetConfig}; use util::vnet::router::{Router, RouterConfig}; @@ -374,3 +382,43 @@ async fn test_get_stats() -> Result<()> { Ok(()) } + +#[tokio::test] +async fn test_peer_connection_close_is_send() -> Result<()> { + let handle = tokio::spawn(async move { peer().await }); + tokio::join!(handle).0.unwrap() +} + +async fn peer() -> Result<()> { + let mut m = MediaEngine::default(); + m.register_default_codecs()?; + let mut registry = Registry::new(); + registry = register_default_interceptors(registry, &mut m)?; + let api = APIBuilder::new() + .with_media_engine(m) + .with_interceptor_registry(registry) + .build(); + + let config = RTCConfiguration { + ice_servers: vec![RTCIceServer { + urls: vec!["stun:stun.l.google.com:19302".to_owned()], + ..Default::default() + }], + ..Default::default() + }; + + let peer_connection = Arc::new(api.new_peer_connection(config).await?); + + let offer = peer_connection.create_offer(None).await?; + let mut gather_complete = peer_connection.gathering_complete_promise().await; + peer_connection.set_local_description(offer).await?; + let _ = gather_complete.recv().await; + + if peer_connection.local_description().await.is_some() { + //TODO? + } + + peer_connection.close().await?; + + Ok(()) +} diff --git a/webrtc/src/peer_connection/sdp/mod.rs b/webrtc/src/peer_connection/sdp/mod.rs index 2b59d9617..b98836378 100644 --- a/webrtc/src/peer_connection/sdp/mod.rs +++ b/webrtc/src/peer_connection/sdp/mod.rs @@ -463,7 +463,7 @@ pub(crate) async fn add_transceiver_sdp( } if codecs.is_empty() { // If we are sender and we have no codecs throw an error early - if t.sender().track().await.is_some() { + if t.sender().await.track().await.is_some() { return Err(Error::ErrSenderWithNoCodecs); } @@ -530,7 +530,7 @@ pub(crate) async fn add_transceiver_sdp( } for mt in transceivers { - let sender = mt.sender(); + let sender = mt.sender().await; if let Some(track) = sender.track().await { media = media.with_media_source( sender.ssrc, diff --git a/webrtc/src/peer_connection/sdp/sdp_test.rs b/webrtc/src/peer_connection/sdp/sdp_test.rs index 666a2dc4e..52c63be12 100644 --- a/webrtc/src/peer_connection/sdp/sdp_test.rs +++ b/webrtc/src/peer_connection/sdp/sdp_test.rs @@ -642,17 +642,19 @@ async fn test_media_description_fingerprints() -> Result<()> { "video".to_owned(), "webrtc-rs".to_owned(), )); - media[i].transceivers[0].set_sender(Arc::new( - RTCRtpSender::new( - api.setting_engine.get_receive_mtu(), - Some(track), - Arc::new(RTCDtlsTransport::default()), - Arc::clone(&api.media_engine), - Arc::clone(&interceptor), - false, - ) - .await, - )); + media[i].transceivers[0] + .set_sender(Arc::new( + RTCRtpSender::new( + api.setting_engine.get_receive_mtu(), + Some(track), + Arc::new(RTCDtlsTransport::default()), + Arc::clone(&api.media_engine), + Arc::clone(&interceptor), + false, + ) + .await, + )) + .await; media[i].transceivers[0].set_direction_internal(RTCRtpTransceiverDirection::Sendonly); } diff --git a/webrtc/src/rtp_transceiver/mod.rs b/webrtc/src/rtp_transceiver/mod.rs index c019d3f04..4b9616f92 100644 --- a/webrtc/src/rtp_transceiver/mod.rs +++ b/webrtc/src/rtp_transceiver/mod.rs @@ -32,7 +32,6 @@ pub mod rtp_receiver; pub mod rtp_sender; pub mod rtp_transceiver_direction; pub(crate) mod srtp_writer_future; -use util::sync::Mutex as SyncMutex; /// SSRC represents a synchronization source /// A synchronization source is a randomly chosen @@ -176,9 +175,9 @@ pub type TriggerNegotiationNeededFnOption = /// RTPTransceiver represents a combination of an RTPSender and an RTPReceiver that share a common mid. pub struct RTCRtpTransceiver { - mid: OnceCell, //atomic.Value - sender: SyncMutex>, //atomic.Value - receiver: SyncMutex>, //atomic.Value + mid: OnceCell, //atomic.Value + sender: Mutex>, //atomic.Value + receiver: Mutex>, //atomic.Value direction: AtomicU8, //RTPTransceiverDirection current_direction: AtomicU8, //RTPTransceiverDirection @@ -208,8 +207,8 @@ impl RTCRtpTransceiver { let t = Arc::new(RTCRtpTransceiver { mid: OnceCell::new(), - sender: SyncMutex::new(sender), - receiver: SyncMutex::new(receiver), + sender: Mutex::new(sender), + receiver: Mutex::new(receiver), direction: AtomicU8::new(direction as u8), current_direction: AtomicU8::new(RTCRtpTransceiverDirection::Unspecified as u8), @@ -220,7 +219,9 @@ impl RTCRtpTransceiver { media_engine, trigger_negotiation_needed: Mutex::new(trigger_negotiation_needed), }); - t.sender().set_rtp_transceiver(Some(Arc::downgrade(&t))); + t.sender() + .await + .set_rtp_transceiver(Some(Arc::downgrade(&t))); t } @@ -250,8 +251,8 @@ impl RTCRtpTransceiver { } /// sender returns the RTPTransceiver's RTPSender if it has one - pub fn sender(&self) -> Arc { - let sender = self.sender.lock(); + pub async fn sender(&self) -> Arc { + let sender = self.sender.lock().await; sender.clone() } @@ -261,33 +262,33 @@ impl RTCRtpTransceiver { sender: Arc, track: Option>, ) -> Result<()> { - self.set_sender(sender); + self.set_sender(sender).await; self.set_sending_track(track).await } - pub fn set_sender(self: &Arc, s: Arc) { + pub async fn set_sender(self: &Arc, s: Arc) { s.set_rtp_transceiver(Some(Arc::downgrade(self))); - let prev_sender = self.sender(); + let prev_sender = self.sender().await; prev_sender.set_rtp_transceiver(None); { - let mut sender = self.sender.lock(); + let mut sender = self.sender.lock().await; *sender = s; } } /// receiver returns the RTPTransceiver's RTPReceiver if it has one - pub fn receiver(&self) -> Arc { - let receiver = self.receiver.lock(); + pub async fn receiver(&self) -> Arc { + let receiver = self.receiver.lock().await; receiver.clone() } - pub(crate) fn set_receiver(&self, r: Arc) { + pub(crate) async fn set_receiver(&self, r: Arc) { r.set_transceiver_codecs(Some(Arc::clone(&self.codecs))); { - let mut receiver = self.receiver.lock(); + let mut receiver = self.receiver.lock().await; (*receiver).set_transceiver_codecs(None); *receiver = r; @@ -398,7 +399,7 @@ impl RTCRtpTransceiver { } { - let receiver = self.receiver.lock().clone(); + let receiver = self.receiver.lock().await; let pause_receiver = !current_direction.has_recv(); if pause_receiver { @@ -410,7 +411,7 @@ impl RTCRtpTransceiver { let pause_sender = !current_direction.has_send(); { - let sender = &*self.sender.lock(); + let sender = &*self.sender.lock().await; sender.set_paused(pause_sender); } @@ -426,11 +427,11 @@ impl RTCRtpTransceiver { self.stopped.store(true, Ordering::SeqCst); { - let sender = self.sender.lock(); + let sender = self.sender.lock().await; sender.stop().await?; } { - let r = self.receiver.lock(); + let r = self.receiver.lock().await; r.stop().await?; } @@ -445,7 +446,7 @@ impl RTCRtpTransceiver { ) -> Result<()> { let track_is_none = track.is_none(); { - let sender = self.sender.lock().clone(); + let sender = self.sender.lock().await; sender.replace_track(track).await?; } diff --git a/webrtc/src/rtp_transceiver/rtp_sender/rtp_sender_test.rs b/webrtc/src/rtp_transceiver/rtp_sender/rtp_sender_test.rs index 676ef4e3a..3448f1f48 100644 --- a/webrtc/src/rtp_transceiver/rtp_sender/rtp_sender_test.rs +++ b/webrtc/src/rtp_transceiver/rtp_sender/rtp_sender_test.rs @@ -133,7 +133,7 @@ async fn test_rtp_sender_get_parameters() -> Result<()> { signal_pair(&mut offerer, &mut answerer).await?; - let sender = rtp_transceiver.sender(); + let sender = rtp_transceiver.sender().await; let parameters = sender.get_parameters().await; assert_ne!(0, parameters.rtp_parameters.codecs.len()); assert_eq!(1, parameters.encodings.len());