diff --git a/examples/webrtc.rs b/examples/webrtc.rs index 5a5e909..317a7f5 100644 --- a/examples/webrtc.rs +++ b/examples/webrtc.rs @@ -44,7 +44,7 @@ async fn main() -> Result<()> { record.args() ) }) - .filter(None, log::LevelFilter::Debug) + .filter(Some("hbb_common"), log::LevelFilter::Debug) .init(); } @@ -56,17 +56,23 @@ async fn main() -> Result<()> { let webrtc_stream = hbb_common::webrtc::WebRTCStream::new(&remote_endpoint, 30000).await?; // Print the offer to be sent to the other peer - webrtc_stream.get_local_endpoint().await; + let local_endpoint = webrtc_stream.get_local_endpoint().await?; if remote_endpoint.is_empty() { + println!(); // Wait for the answer to be pasted - println!("Wait for the answer to be pasted"); + println!( + "Start new terminal run: \n{} \ncopy remote endpoint and paste here", + format!("cargo r --example webrtc -- --offer {}", local_endpoint) + ); // readline blocking let line = std::io::stdin() .lines() .next() .ok_or_else(|| anyhow::anyhow!("No input received"))??; webrtc_stream.set_remote_endpoint(&line).await?; + } else { + println!("Copy local endpoint and paste to the other peer: \n{}", local_endpoint); } let s1 = hbb_common::Stream::WebRTC(webrtc_stream.clone()); @@ -93,16 +99,20 @@ async fn main() -> Result<()> { async fn read_loop(mut stream: hbb_common::Stream) -> Result<()> { loop { let Some(res) = stream.next().await else { - println!("Datachannel closed; Exit the read_loop"); + println!("WebRTC stream closed; Exit the read_loop"); return Ok(()); }; - println!("Message from DataChannel: {}", + if res.is_err() { + println!("WebRTC stream read error: {}; Exit the read_loop", res.err().unwrap()); + return Ok(()); + } + println!("Message from stream: {}", String::from_utf8(res.unwrap().to_vec())? ); } } -// write_loop shows how to write to the datachannel directly +// write_loop shows how to write to the webrtc stream directly async fn write_loop(mut stream: hbb_common::Stream) -> Result<()> { let mut result = Result::<()>::Ok(()); while result.is_ok() { @@ -112,12 +122,12 @@ async fn write_loop(mut stream: hbb_common::Stream) -> Result<()> { tokio::select! { _ = timeout.as_mut() =>{ let message = math_rand_alpha(15); - println!("Sending '{message}'"); - result = stream.send_bytes(Bytes::from(message)).await; + result = stream.send_bytes(Bytes::from(message.clone())).await; + println!("Sent '{message}' {}", result.is_ok()); } }; } - println!("Datachannel write not ok; Exit the write_loop"); + println!("WebRTC stream write failed; Exit the write_loop"); Ok(()) } \ No newline at end of file diff --git a/src/webrtc.rs b/src/webrtc.rs index b276e77..3a3ed01 100644 --- a/src/webrtc.rs +++ b/src/webrtc.rs @@ -7,7 +7,6 @@ use std::collections::HashMap; use webrtc::api::APIBuilder; use webrtc::api::setting_engine::SettingEngine; use webrtc::data_channel::RTCDataChannel; -use webrtc::data_channel::data_channel_state::RTCDataChannelState; use webrtc::ice_transport::ice_server::RTCIceServer; use webrtc::peer_connection::RTCPeerConnection; use webrtc::peer_connection::configuration::RTCConfiguration; @@ -30,7 +29,7 @@ use tokio::sync::Mutex; pub struct WebRTCStream { pc: Arc, - stream: Arc, + stream: Arc>>, state_notify: watch::Receiver, send_timeout: u64, } @@ -77,11 +76,24 @@ impl WebRTCStream { format!("webrtc://{}", encoded_sdp) } + async fn get_key_for_peer(pc: &Arc) -> String { + if let Some(local_desc) = pc.local_description().await { + if local_desc.sdp_type != webrtc::peer_connection::sdp::sdp_type::RTCSdpType::Offer { + let Some(remote_desc) = pc.remote_description().await else { + return "".into(); + }; + return serde_json::to_string(&remote_desc).unwrap_or_default(); + } + return serde_json::to_string(&local_desc).unwrap_or_default(); + } + "".into() + } + pub async fn new( remote_endpoint: &str, ms_timeout: u64, ) -> ResultType { - log::debug!("New webrtc stream with endpoint: {}", remote_endpoint); + log::debug!("New webrtc stream to endpoint: {}", remote_endpoint); let remote_offer = if remote_endpoint.is_empty() { "".into() } else { @@ -116,35 +128,61 @@ impl WebRTCStream { }; let (notify_tx, notify_rx) = watch::channel(false); - let on_open_notify = notify_tx.clone(); + let dc_open_notify = notify_tx.clone(); // Create a new RTCPeerConnection - let peer_connection = Arc::new(api.new_peer_connection(config).await?); - let data_channel = peer_connection.create_data_channel("bootstrap", None).await?; - data_channel.on_open(Box::new(move || { - log::debug!("Data channel bootstrap open."); - let _ = on_open_notify.send(true); + let pc = Arc::new(api.new_peer_connection(config).await?); + let bootstrap_dc = if remote_offer.is_empty() { + // Create a data channel with label "bootstrap" + pc.create_data_channel("bootstrap", None).await? + } else { + // Wait for the data channel to be created by the remote peer + // Here we create a dummy data channel to satisfy the type system + Arc::new(RTCDataChannel::default()) + }; + bootstrap_dc.on_open(Box::new(move || { + log::debug!("Local data channel bootstrap open."); + let _ = dc_open_notify.send(true); Box::pin(async {}) })); + let stream = Arc::new(Mutex::new(bootstrap_dc.clone())); + // This will notify you when the peer has connected/disconnected let on_connection_notify = notify_tx.clone(); - peer_connection.on_peer_connection_state_change(Box::new(move |s: RTCPeerConnectionState| { - log::debug!("Peer Connection State has changed: {}", s); - if s == RTCPeerConnectionState::Disconnected { - let _ = on_connection_notify.send(true); - } - - // TODO clear SESSIONS entry? - Box::pin(async {}) + let stream_for_close = stream.clone(); + let pc_for_close = pc.clone(); + pc.on_peer_connection_state_change(Box::new(move |s: RTCPeerConnectionState| { + let stream_for_close2 = stream_for_close.clone(); + let on_connection_notify2 = on_connection_notify.clone(); + let pc_for_close2 = pc_for_close.clone(); + Box::pin(async move { + log::debug!("Peer connection state : {}", s); + if s == RTCPeerConnectionState::Disconnected { + let _ = on_connection_notify2.send(true); + log::debug!("WebRTC session closing due to disconnected"); + let _ = stream_for_close2.lock().await.close().await; + log::debug!("WebRTC session stream closed"); + } else if s == RTCPeerConnectionState::Failed || s == RTCPeerConnectionState::Closed { + let mut lock = SESSIONS.lock().await; + let key = WebRTCStream::get_key_for_peer(&pc_for_close2).await; + log::debug!("WebRTC session removing key from cache: {}", key); + lock.remove(&key); + } + }) })); // Register data channel creation handling - let on_open_notify2 = notify_tx.clone(); - peer_connection.on_data_channel(Box::new(move |dc: Arc| { + let remote_dc_open_notify = notify_tx.clone(); + let stream_for_dc = stream.clone(); + pc.on_data_channel(Box::new(move |dc: Arc| { let d_label = dc.label().to_owned(); - log::debug!("Remote data channel {}", d_label); - let notify = on_open_notify2.clone(); + let notify = remote_dc_open_notify.clone(); + let stream_for_dc_clone = stream_for_dc.clone(); + log::debug!("Remote data channel {} ready", d_label); Box::pin(async move { + let mut stream_lock = stream_for_dc_clone.lock().await; + *stream_lock = dc.clone(); + drop(stream_lock); dc.on_open(Box::new(move || { let _ = notify.send(true); Box::pin(async {}) @@ -152,30 +190,28 @@ impl WebRTCStream { }) })); + // process offer/answer if remote_offer.is_empty() { - let sdp = peer_connection.create_offer(None).await?; - let mut gather_complete = peer_connection.gathering_complete_promise().await; - peer_connection.set_local_description(sdp.clone()).await?; + let sdp = pc.create_offer(None).await?; + let mut gather_complete = pc.gathering_complete_promise().await; + pc.set_local_description(sdp.clone()).await?; let _ = gather_complete.recv().await; - let final_sdp = peer_connection.local_description().await.ok_or_else(|| { - Error::new(ErrorKind::Other, "Failed to get local description after gathering") - })?; - key = serde_json::to_string(&final_sdp).unwrap_or_default(); + key = Self::get_key_for_peer(&pc).await; log::debug!("Start webrtc with local: {}", key); } else { let sdp = serde_json::from_str::(&remote_offer)?; - peer_connection.set_remote_description(sdp).await?; - let answer = peer_connection.create_answer(None).await?; - let mut gather_complete = peer_connection.gathering_complete_promise().await; - peer_connection.set_local_description(answer).await?; + pc.set_remote_description(sdp).await?; + let answer = pc.create_answer(None).await?; + let mut gather_complete = pc.gathering_complete_promise().await; + pc.set_local_description(answer).await?; let _ = gather_complete.recv().await; log::debug!("Start webrtc with remote: {}", remote_offer); } let webrtc_stream = WebRTCStream { - pc: peer_connection, - stream: data_channel, + pc, + stream, state_notify: notify_rx, send_timeout: ms_timeout, }; @@ -185,14 +221,13 @@ impl WebRTCStream { } #[inline] - pub async fn get_local_endpoint(&self) -> Option { + pub async fn get_local_endpoint(&self) -> ResultType { if let Some(local_desc) = self.pc.local_description().await { let sdp = serde_json::to_string(&local_desc).unwrap_or_default(); let endpoint = Self::sdp_to_endpoint(&sdp); - log::debug!("WebRTC get local endpoint: {}", endpoint); - Some(endpoint) + Ok(endpoint) } else { - None + Err(anyhow::anyhow!("Local description is not set")) } } @@ -240,29 +275,33 @@ impl WebRTCStream { self.send_bytes(Bytes::from(msg)).await } - pub async fn send_bytes(&mut self, bytes: Bytes) -> ResultType<()> { + #[inline] + async fn wait_for_connect_result(&mut self) { + if *self.state_notify.borrow() { + return; + } let _ = self.state_notify.changed().await; - self.stream.send(&bytes).await?; + } + + pub async fn send_bytes(&mut self, bytes: Bytes) -> ResultType<()> { + self.wait_for_connect_result().await; + let stream = self.stream.lock().await.clone(); + stream.send(&bytes).await?; Ok(()) } #[inline] pub async fn next(&mut self) -> Option> { - // wait for connected or disconnected - let _ = self.state_notify.changed().await; - if self.stream.ready_state() != RTCDataChannelState::Open { - return Some(Err(Error::new( - ErrorKind::Other, - "data channel is closed", - ))); - } + self.wait_for_connect_result().await; + let stream = self.stream.lock().await.clone(); // TODO reuse buffer? let mut buffer = BytesMut::zeroed(DATA_CHANNEL_BUFFER_SIZE as usize); - let dc = self.stream.detach().await.ok()?; + let dc = stream.detach().await.ok()?; let n = match dc.read(&mut buffer).await { Ok(n) => n, Err(err) => { + self.pc.close().await.ok(); return Some(Err(Error::new( ErrorKind::Other, format!("data channel read error: {}", err), @@ -270,12 +309,12 @@ impl WebRTCStream { } }; if n == 0 { + self.pc.close().await.ok(); return Some(Err(Error::new( ErrorKind::Other, "data channel read exited with 0 bytes", ))); } - log::debug!("WebRTCStream read {} bytes", n); buffer.truncate(n); Some(Ok(buffer)) }