From 5dcfea1ee4950c1076fffaec7de86ce6c7fc5e8f Mon Sep 17 00:00:00 2001 From: lc Date: Fri, 14 Nov 2025 16:02:51 +0800 Subject: [PATCH] support send_timeout --- examples/webrtc.rs | 16 +++++-- src/webrtc.rs | 113 ++++++++++++++++++++++++++++----------------- 2 files changed, 81 insertions(+), 48 deletions(-) diff --git a/examples/webrtc.rs b/examples/webrtc.rs index 40d4cdf..e9af423 100644 --- a/examples/webrtc.rs +++ b/examples/webrtc.rs @@ -8,10 +8,10 @@ mod webrtc_dummy; use crate::webrtc_dummy::WebRTCStream; use std::io::Write; -use bytes::Bytes; -use clap::{Arg, Command}; use anyhow::Result; +use bytes::Bytes; +use clap::{Arg, Command}; use tokio::time::Duration; use webrtc::peer_connection::math_rand_alpha; @@ -75,7 +75,10 @@ async fn main() -> Result<()> { // Wait for the answer to be pasted println!( "Start new terminal run: \n{} \ncopy remote endpoint and paste here", - format!("cargo r --features webrtc --example webrtc -- --offer {}", local_endpoint) + format!( + "cargo r --features webrtc --example webrtc -- --offer {}", + local_endpoint + ) ); // readline blocking let line = std::io::stdin() @@ -84,7 +87,10 @@ async fn main() -> Result<()> { .ok_or_else(|| anyhow::anyhow!("No input received"))??; webrtc_stream.set_remote_endpoint(&line).await?; } else { - println!("Copy local endpoint and paste to the other peer: \n{}", local_endpoint); + println!( + "Copy local endpoint and paste to the other peer: \n{}", + local_endpoint + ); } let s1 = webrtc_stream.clone(); @@ -144,4 +150,4 @@ async fn write_loop(mut stream: WebRTCStream) -> Result<()> { println!("WebRTC stream write failed; Exit the write_loop"); Ok(()) -} \ No newline at end of file +} diff --git a/src/webrtc.rs b/src/webrtc.rs index 6046b7e..8ffb91b 100644 --- a/src/webrtc.rs +++ b/src/webrtc.rs @@ -1,31 +1,29 @@ -use std::sync::Arc; -use std::net::{IpAddr, Ipv4Addr, SocketAddr}; -use std::io::{Error, ErrorKind}; -use std::time::Duration; use std::collections::HashMap; +use std::io::{Error, ErrorKind}; +use std::net::{IpAddr, Ipv4Addr, SocketAddr}; +use std::sync::Arc; +use std::time::Duration; -use webrtc::api::APIBuilder; use webrtc::api::setting_engine::SettingEngine; +use webrtc::api::APIBuilder; use webrtc::data_channel::RTCDataChannel; +use webrtc::ice::mdns::MulticastDnsMode; use webrtc::ice_transport::ice_server::RTCIceServer; -use webrtc::peer_connection::RTCPeerConnection; use webrtc::peer_connection::configuration::RTCConfiguration; use webrtc::peer_connection::peer_connection_state::RTCPeerConnectionState; use webrtc::peer_connection::sdp::session_description::RTCSessionDescription; -use webrtc::ice::mdns::MulticastDnsMode; +use webrtc::peer_connection::RTCPeerConnection; -use crate::{ - protobuf::Message, - sodiumoxide::crypto::secretbox::Key, - ResultType, -}; - -use base64::Engine; use base64::engine::general_purpose::STANDARD as BASE64_STANDARD; +use base64::Engine; use bytes::{Bytes, BytesMut}; -use tokio::time::timeout; use tokio::sync::watch; use tokio::sync::Mutex; +use tokio::time::timeout; + +use crate::protobuf::Message; +use crate::sodiumoxide::crypto::secretbox::Key; +use crate::ResultType; pub struct WebRTCStream { pc: Arc, @@ -53,29 +51,36 @@ impl Clone for WebRTCStream { } impl WebRTCStream { - - pub fn get_remote_offer(endpoint: &str) -> ResultType { + #[inline] + fn get_remote_offer(endpoint: &str) -> ResultType { // Ensure the endpoint starts with the "webrtc://" prefix if !endpoint.starts_with("webrtc://") { - return Err(Error::new(ErrorKind::InvalidInput, "Invalid WebRTC endpoint format").into()); + return Err( + Error::new(ErrorKind::InvalidInput, "Invalid WebRTC endpoint format").into(), + ); } // Extract the Base64-encoded SDP part let encoded_sdp = &endpoint["webrtc://".len()..]; // Decode the Base64 string - let decoded_bytes = BASE64_STANDARD.decode(encoded_sdp).map_err(|_| - Error::new(ErrorKind::InvalidInput, "Failed to decode Base64 SDP") - )?; + let decoded_bytes = BASE64_STANDARD + .decode(encoded_sdp) + .map_err(|_| Error::new(ErrorKind::InvalidInput, "Failed to decode Base64 SDP"))?; Ok(String::from_utf8(decoded_bytes).map_err(|_| { - Error::new(ErrorKind::InvalidInput, "Failed to convert decoded bytes to UTF-8") + Error::new( + ErrorKind::InvalidInput, + "Failed to convert decoded bytes to UTF-8", + ) })?) } - pub fn sdp_to_endpoint(sdp: &str) -> String { + #[inline] + fn sdp_to_endpoint(sdp: &str) -> String { let encoded_sdp = BASE64_STANDARD.encode(sdp); format!("webrtc://{}", encoded_sdp) } + #[inline] async fn get_key_for_peer(pc: &Arc) -> String { if let Some(local_desc) = pc.local_description().await { if local_desc.sdp_type != webrtc::peer_connection::sdp::sdp_type::RTCSdpType::Offer { @@ -89,10 +94,7 @@ impl WebRTCStream { "".into() } - pub async fn new( - remote_endpoint: &str, - ms_timeout: u64, - ) -> ResultType { + pub async fn new(remote_endpoint: &str, ms_timeout: u64) -> ResultType { log::debug!("New webrtc stream to endpoint: {}", remote_endpoint); let remote_offer = if remote_endpoint.is_empty() { "".into() @@ -115,9 +117,7 @@ impl WebRTCStream { s.set_ice_multicast_dns_mode(MulticastDnsMode::Disabled); // Create the API object - let api = APIBuilder::new() - .with_setting_engine(s) - .build(); + let api = APIBuilder::new().with_setting_engine(s).build(); // Prepare the configuration let config = RTCConfiguration { @@ -158,16 +158,25 @@ impl WebRTCStream { let pc_for_close2 = pc_for_close.clone(); Box::pin(async move { log::debug!("Peer connection state : {}", s); - if s == RTCPeerConnectionState::Disconnected { - let _ = on_connection_notify2.send(true); - log::debug!("WebRTC session closing due to disconnected"); - let _ = stream_for_close2.lock().await.close().await; - log::debug!("WebRTC session stream closed"); - } else if s == RTCPeerConnectionState::Failed || s == RTCPeerConnectionState::Closed { - let mut lock = SESSIONS.lock().await; - let key = WebRTCStream::get_key_for_peer(&pc_for_close2).await; - log::debug!("WebRTC session removing key from cache: {}", key); - lock.remove(&key); + match s { + RTCPeerConnectionState::Disconnected + | RTCPeerConnectionState::Failed + | RTCPeerConnectionState::Closed => { + let _ = on_connection_notify2.send(true); + log::debug!("WebRTC session closing due to disconnected"); + let _ = stream_for_close2.lock().await.close().await; + log::debug!("WebRTC session stream closed"); + + let mut lock = SESSIONS.lock().await; + let key = WebRTCStream::get_key_for_peer(&pc_for_close2).await; + lock.remove(&key); + log::debug!( + "WebRTC session removed key from cache: {} current len: {}", + key, + lock.len() + ); + } + _ => {} } }) })); @@ -287,7 +296,26 @@ impl WebRTCStream { } pub async fn send_bytes(&mut self, bytes: Bytes) -> ResultType<()> { - self.wait_for_connect_result().await; + if self.send_timeout > 0 { + match timeout( + Duration::from_millis(self.send_timeout), + self.wait_for_connect_result(), + ) + .await + { + Ok(_) => {} + Err(_) => { + self.pc.close().await.ok(); + return Err(Error::new( + ErrorKind::TimedOut, + "WebRTC send wait for connect timeout", + ) + .into()); + } + } + } else { + self.wait_for_connect_result().await; + } let stream = self.stream.lock().await.clone(); stream.send(&bytes).await?; Ok(()) @@ -339,6 +367,5 @@ pub fn is_webrtc_endpoint(endpoint: &str) -> bool { #[cfg(test)] mod tests { #[test] - fn test_dc() { - } + fn test_dc() {} }